# Add auto loss logging #27
base: main
Changes from all commits: `5cbdf05`, `d891e75`, `7e0eae7`, `b124f32`, `2ae005c`
**File 1**

```diff
@@ -22,7 +22,7 @@
 """

 from dataclasses import dataclass
-from typing import Optional, Union, Tuple, Dict, Sequence, List
+from typing import Optional, Union, Tuple, Dict, Sequence, List, Any

 import torch
@@ -79,6 +79,7 @@ def __post_init__(self):
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -88,6 +89,7 @@ def __call__(
         skipped during validation.

         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
         :param context: Optional Context object containing tensors.
```
> **Member:** Haven't dug into the code, but consider renaming …
```diff
         :param inputs: Keyword arguments containing all necessary inputs for the
             loss computation.
@@ -117,6 +119,21 @@ def __call__(
         return raw, raw * _scalar_from_ctx(self.weight, inputs)

+    def get_config(self) -> Dict[str, Any]:
```
> **Member:** Good idea
```diff
+        """
+        Get the configuration of the LossItem for logging or checkpointing.
+        """
+        return {
+            'module': self.module.__class__.__name__,
+            'args': self.args,
+            'key': self.key,
+            'weight': self.weight,
+            'enabled': self.enabled,
+            'compute_at_val': self.compute_at_val,
+            'device': str(self.device)
+        }
+
+
 @dataclass
 class LossGroup:
     """
```
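For reference, the dict returned by the new `LossItem.get_config()` might look like the following sketch. All values are illustrative, and the shape of `args` in particular is an assumption, not taken from this diff:

```python
# Illustrative only -- one possible LossItem.get_config() result for an
# item wrapping torch.nn.MSELoss. The 'args' value is a guess at what a
# typical configuration might hold.
example_item_config = {
    'module': 'MSELoss',               # class name of the wrapped loss module
    'args': ('prediction', 'target'),  # assumed: context keys fed to the loss
    'key': 'mse',                      # name under which the raw value is logged
    'weight': 0.5,                     # static weight; could also be a schedule
    'enabled': True,
    'compute_at_val': True,            # loss is also computed during validation
    'device': 'cuda:0'
}
```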
```diff
@@ -137,13 +154,15 @@ def item_names(self) -> List[Optional[str]]:
     def __call__(
         self,
         train: bool,
+        epoch: Optional[int] = None,
         context: Optional[Context] = None,
         **inputs: torch.Tensor
     ) -> Tuple[torch.Tensor, Dict[str, Scalar]]:
         """
         Compute the total loss and individual loss values.

         :param train: Whether the model is in training mode.
+        :param epoch: Optional epoch number to determine the weight from the schedule.
         :param context: Optional Context object containing tensors.
         :input inputs: Keyword arguments containing all necessary inputs for the
             loss computations.
```
```diff
@@ -156,8 +175,20 @@ def __call__(
         logs: Dict[str, float] = {}

         for item in self.items:
-            raw, weighted = item(train, context=context, **inputs)
+            raw, weighted = item(
+                train,
+                epoch=epoch,
+                context=context,
+                **inputs
+            )
             logs[item.key] = raw.item()  # type: ignore
             total += weighted

         return total, logs
+
+    def get_config(self) -> List[Dict[str, Any]]:
+        """
+        Get the configuration of the LossGroup for logging or checkpointing.
+        """
+        return [item.get_config() for item in self.items]
```
> **Member:** Consider renaming `self.items` here to something more specific (such as the grouping, like `training_loss_objects`)
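With `epoch` threaded through, a call site could look like the following sketch; the construction of `loss_group` and `ctx` is assumed, and only the keyword signature comes from this diff:

```python
# Hypothetical call -- loss_group and ctx are assumed to exist; the
# signature (train, epoch, context, **inputs) matches the diff above.
total, logs = loss_group(
    train=True,
    epoch=3,       # lets scheduled weights resolve for the current epoch
    context=ctx
)
print(logs)        # raw (unweighted) per-item values, e.g. {'mse': 0.021}
total.backward()   # total is the weighted sum used for the backward pass
```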
**File 2**

```diff
@@ -105,7 +105,11 @@ def train_step(
             targets=targets
         )

-        weighted_total, logs = self._loss_group(train=True, context=ctx)
+        weighted_total, logs = self._loss_group(
+            train=True,
+            epoch=self.epoch,
```
> **Member:** Maybe it would be worth adding this to …
```diff
+            context=ctx
+        )
         weighted_total.backward()
         self._forward_group.step()
@@ -133,12 +137,20 @@ def evaluate_step(
             targets=targets
         )

-        _, logs = self._loss_group(train=False, context=ctx)
+        _, logs = self._loss_group(
+            train=False,
+            epoch=self.epoch,
+            context=ctx
+        )

         for _, metric in self.metrics.items():
             metric.update(*ctx.as_metric_args(), validation=True)

         return logs

+    @property
+    def loss_groups(self) -> Dict[str, LossGroup]:
+        return {'main': self._loss_group}
+
     def save_model(
         self,
```
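Since callbacks can now discover groups through `loss_groups`, a trainer with more than one group could override the property along these lines. This is a hypothetical sketch; the generator/discriminator attribute names are chosen to match the fallbacks the logger uses below:

```python
# Hypothetical override for a GAN-style trainer with two loss groups.
@property
def loss_groups(self) -> Dict[str, LossGroup]:
    return {
        'generator': self._generator_loss_group,
        'discriminator': self._discriminator_loss_group
    }
```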
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -222,6 +222,8 @@ def on_train_start(self): | |
| except Exception as e: | ||
| print(f"Fail to log model config as artifact: {e}") | ||
|
|
||
| self._log_loss_groups_config_and_tags() | ||
|
|
||
|
|
||
| for callback in self.callbacks: | ||
| # TODO consider if we want hasattr checks | ||
|
|
```diff
@@ -502,6 +504,92 @@ def _save_model_weights(
             artifact_path=artifact_path
         )

+    def _get_loss_groups(self) -> Dict[str, Any]:
+        """
+        Discover loss groups attached to the bound trainer.
+        """
```
> **Member** (on lines +508 to +510): Good minimal comment
```diff
+        if self.trainer is None:
+            return {}
+
+        loss_groups: Dict[str, Any] = {}
+
+        explicit_groups = getattr(self.trainer, 'loss_groups', None)
+        if isinstance(explicit_groups, dict):
+            for group_name, group in explicit_groups.items():
+                if hasattr(group, 'get_config'):
+                    loss_groups[str(group_name)] = group
+
+        fallback_attrs = {
+            'main': '_loss_group',
+            'generator': '_generator_loss_group',
+            'discriminator': '_discriminator_loss_group'
+        }
+        for group_name, attr in fallback_attrs.items():
+            if group_name in loss_groups:
+                continue
+            group = getattr(self.trainer, attr, None)
+            if group is not None and hasattr(group, 'get_config'):
+                loss_groups[group_name] = group
+
+        return loss_groups
+
+    def _log_loss_groups_config_and_tags(self) -> None:
+        """
+        Log loss item names and weights as flat tags and full loss group
+        configuration as config artifacts.
+        """
```
> **Member:** Maybe this is mentioned somewhere else, but after reading this comment, I'm not sure what flat tags and full loss groups mean. Consider expanding on this here or mentioning it somewhere else if you haven't already
```diff
+        loss_groups = self._get_loss_groups()
+        if not loss_groups:
+            return None
+
+        for group_name, group in loss_groups.items():
+            try:
+                group_config = group.get_config()
+            except Exception as e:
+                print(
+                    f"Could not get loss group config for logging "
+                    f"({group_name}): {e}"
+                )
+                continue
+
+            if not isinstance(group_config, list):
+                continue
+
+            for idx, item_cfg in enumerate(group_config):
+                if not isinstance(item_cfg, dict):
+                    continue
+
+                if 'key' in item_cfg and item_cfg['key'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.name",
+                        str(item_cfg['key'])
+                    )
```

> **Member:** Consider naming this to something more specific if possible
```diff
+
+                if 'weight' in item_cfg and item_cfg['weight'] is not None:
+                    mlflow.set_tag(
+                        f"loss.{group_name}.{idx}.weight",
+                        str(item_cfg['weight'])
+                    )
+
+            try:
+                self.log_config(
+                    tag=f"loss_group_{group_name}",
+                    config={
+                        'group_name': group_name,
+                        'items': group_config
+                    },
+                    stage=None
+                )
+            except Exception as e:
+                print(
+                    f"Fail to log loss group config as artifact "
+                    f"({group_name}): {e}"
+                )
+
+        return None
+
     def log_config(
         self,
         tag: str,
```
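For a `main` group with two items, the tag-setting loop above would be equivalent to something like this (tag keys follow the code; values are illustrative):

```python
# Hand-written equivalent of the tags the loop would set (values assumed).
import mlflow

mlflow.set_tag("loss.main.0.name", "mse")
mlflow.set_tag("loss.main.0.weight", "0.5")
mlflow.set_tag("loss.main.1.name", "kl")
mlflow.set_tag("loss.main.1.weight", "0.01")
```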
> I think some weight schedulers also use per batch, so you could consider adding this as well, or maybe it could be something like:
>
> `:param progress: Optional training progress index (step or epoch) used to determine the scheduled weight.`
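If that suggestion were adopted, the signature might become something like the following sketch (not part of this PR):

```python
# Hypothetical variant of LossItem.__call__ using a generic progress
# index, so both per-epoch and per-batch weight schedules can resolve.
def __call__(
    self,
    train: bool,
    progress: Optional[int] = None,
    context: Optional[Context] = None,
    **inputs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...
```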