Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `UW` (Uncertainty Weighting) from [Multi-Task Learning Using Uncertainty to Weigh Losses
for Scene Geometry and
Semantics](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf),
a `Scalarizer` that combines the values using learned per-task uncertainties. It is the first
stateful, trainable scalarizer: its log-variances are an `nn.Parameter` that must be passed to
the optimizer.
- Added `STCH` from [Smooth Tchebycheff Scalarization for Multi-Objective
Optimization](https://openreview.net/pdf?id=m4dO5L6eCp), a `Scalarizer` that combines the input
tensor of values into a smooth approximation of their (weighted, shifted) maximum.
Expand Down
13 changes: 11 additions & 2 deletions docs/source/docs/scalarization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@ scalarization
.. automodule:: torchjd.scalarization
:no-members:

Abstract base class
-------------------
Abstract base classes
---------------------

.. autoclass:: torchjd.scalarization.Scalarizer
:members: __call__

.. py:class:: torchjd.scalarization.Stateful

Mixin adding a reset method.

.. py:method:: reset()

Resets the internal state.


.. toctree::
:hidden:
Expand All @@ -21,3 +29,4 @@ Abstract base class
random.rst
stch.rst
sum.rst
uw.rst
7 changes: 7 additions & 0 deletions docs/source/docs/scalarization/uw.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:hide-toc:

UW
==

.. autoclass:: torchjd.scalarization.UW
:members: __call__, reset
9 changes: 9 additions & 0 deletions src/torchjd/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from abc import ABC, abstractmethod
from importlib.util import find_spec
from typing import Any


class Stateful(ABC):
    """
    Mixin for objects that carry internal state which can be cleared.

    Concrete subclasses must implement :meth:`reset`; the class cannot be instantiated until they
    do.
    """

    @abstractmethod
    def reset(self) -> None:
        """Resets the internal state."""
        ...


class _WithOptionalDeps:
"""
Mixin that raises :class:`ImportError` at instantiation time if required optional dependencies
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_cr_mogm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class CRMOGMWeighting(Weighting[_T], Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
:class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
Expand Down Expand Up @@ -120,7 +120,7 @@ def alpha(self, value: float) -> None:
def reset(self) -> None:
r"""
Clears the EMA state so the next forward restarts from the initial state. Also resets the
wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`.
wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.
"""

if isinstance(self.weighting, Stateful):
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Non-differentiable: weights are modified in-place during the gradient correction loop.
class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.GradVac`.

Expand Down Expand Up @@ -130,7 +130,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:

class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
Expand Down
9 changes: 1 addition & 8 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import nn


class Stateful(ABC):
"""Mixin adding a reset method."""

@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""
from torchjd._mixins import Stateful as Stateful


class _NonDifferentiable(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_modo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class MoDoWeighting(_MatrixWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] from `Three-Way
Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance
<https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024).
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDiffe
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.Matrix`] that
extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
Game <https://arxiv.org/pdf/2202.01017.pdf>`_.
Expand Down Expand Up @@ -206,7 +206,7 @@ def reset(self) -> None:

class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
`Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

Expand Down
15 changes: 14 additions & 1 deletion src/torchjd/scalarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,25 @@
tensor(2.)
"""

from torchjd._mixins import Stateful

from ._constant import Constant
from ._geometric_mean import GeometricMean
from ._mean import Mean
from ._random import Random
from ._scalarizer_base import Scalarizer
from ._stch import STCH
from ._sum import Sum
from ._uw import UW

__all__ = ["Constant", "GeometricMean", "Mean", "Random", "Scalarizer", "STCH", "Sum"]
__all__ = [
"Constant",
"GeometricMean",
"Mean",
"Random",
"Scalarizer",
"STCH",
"Stateful",
"Sum",
"UW",
]
81 changes: 81 additions & 0 deletions src/torchjd/scalarization/_uw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from collections.abc import Sequence

import torch
from torch import Tensor, nn

from torchjd._mixins import Stateful

from ._scalarizer_base import Scalarizer


class UW(Scalarizer, Stateful):
    r"""
    :class:`~torchjd.scalarization.Stateful`
    :class:`~torchjd.scalarization.Scalarizer` that combines the input tensor of values using
    learned per-task uncertainties. ``UW`` is short for Uncertainty Weighting, the method proposed
    in `Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
    <https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf>`_.

    Each value :math:`L_i` is assigned a learnable log-variance :math:`s_i`, and the values are
    combined as

    .. math::
        \sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)

    where:

    - :math:`L_i` is the :math:`i`-th value (typically the loss of task :math:`i`);
    - :math:`s_i = \log \sigma_i^2` is the learnable log-variance of task :math:`i`.

    Following the paper, the log-variance :math:`s_i` is learned rather than the variance
    :math:`\sigma_i^2` directly: this is numerically more stable (the combination never divides by
    zero) and keeps :math:`s_i` unconstrained, since :math:`e^{-s_i}` is always positive. The
    :math:`s_i` are stored as an ``nn.Parameter``, so the parameters of this scalarizer must be
    passed to the optimizer to be learned jointly with the model.

    :param shape: The shape of the values to scalarize, used to create one log-variance per value.
        An ``int`` ``n`` is interpreted as the shape ``(n,)``.

    The following example shows how to train a model with Uncertainty Weighting, as described in
    the paper.

        >>> import torch
        >>> from torch.nn import Linear
        >>>
        >>> from torchjd.scalarization import UW
        >>>
        >>> model = Linear(3, 2)
        >>> scalarizer = UW(2)  # Move to the right device with e.g. UW(2).to(device="cuda")
        >>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
        >>>
        >>> features = torch.randn(8, 3)
        >>> # Compute some dummy losses just for the sake of the example
        >>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
        >>> loss = scalarizer(losses)
        >>> loss.backward()
        >>> optimizer.step()

    .. note::
        The log-variances are initialized to ``0`` (i.e. :math:`\sigma_i^2 = 1`), which gives
        uniform weights at the start of training. The paper reports that the result is robust to
        this initialization. (`LibMTL <https://github.com/median-research-group/LibMTL>`_
        initializes them to ``-0.5`` instead.)
    """

    def __init__(self, shape: int | Sequence[int]) -> None:
        super().__init__()
        # One learnable log-variance per value, all starting at 0.
        initial_log_var = torch.zeros(shape)
        self.log_var = nn.Parameter(initial_log_var)

    def forward(self, values: Tensor, /) -> Tensor:
        # The shapes must match exactly: no broadcasting between values and log-variances.
        expected_shape = tuple(self.log_var.shape)
        found_shape = tuple(values.shape)
        if found_shape != expected_shape:
            raise ValueError(
                f"Parameter `values` should have shape {expected_shape} (matching the "
                f"shape of the log-variances). Found `values.shape = {found_shape}`.",
            )

        # 0.5 * exp(-s_i) * L_i: each value scaled by half of its learned precision.
        half_precision = 0.5 * torch.exp(-self.log_var)
        weighted_values = half_precision * values
        # 0.5 * s_i: the log-variance term of the objective.
        regularization = 0.5 * self.log_var
        return (weighted_values + regularization).sum()

    def reset(self) -> None:
        """Resets the log-variances to ``0``, the value they are created with."""

        # `nn.init.zeros_` zeroes the parameter in-place without tracking gradients.
        nn.init.zeros_(self.log_var)

    def __repr__(self) -> str:
        class_name = type(self).__name__
        shape = tuple(self.log_var.shape)
        return f"{class_name}(shape={shape})"
100 changes: 100 additions & 0 deletions tests/unit/scalarization/test_uw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from contextlib import nullcontext as does_not_raise

import torch
from pytest import mark, raises
from settings import DEVICE, DTYPE
from torch import Tensor
from utils.contexts import ExceptionContext
from utils.tensors import ones_, tensor_, zeros_

from torchjd.scalarization import UW

from ._asserts import assert_grad_flow, assert_returns_scalar
from ._inputs import all_inputs


def _uw(shape: int | tuple[int, ...]) -> UW:
    """Creates a `UW` of the given shape and moves it to the test device and dtype."""
    scalarizer = UW(shape)
    return scalarizer.to(device=DEVICE, dtype=DTYPE)


def test_value() -> None:
    """Checks the output on known values for a freshly created scalarizer."""

    # The log-variances start at 0, so each weight is 0.5 * exp(0) = 0.5 and the log-variance
    # term vanishes: the expected result is 0.5 * (1 + 2 + 4) = 3.5.
    values = tensor_([1.0, 2.0, 4.0])
    scalarizer = _uw((3,))
    result = scalarizer(values)
    torch.testing.assert_close(result, tensor_(3.5))


def test_int_shape_matches_tuple_shape() -> None:
    """Checks that an ``int`` shape ``n`` behaves like the tuple shape ``(n,)``."""

    values = tensor_([1.0, 2.0, 4.0])
    assert UW(3).log_var.shape == (3,)

    result_from_int = _uw(3)(values)
    result_from_tuple = _uw((3,))(values)
    torch.testing.assert_close(result_from_int, result_from_tuple)


@mark.parametrize("values", all_inputs)
def test_expected_structure(values: Tensor) -> None:
    """Checks that the output is a scalar, with a scalarizer shaped like the input."""

    scalarizer = _uw(tuple(values.shape))
    assert_returns_scalar(scalarizer, values)


@mark.parametrize("values", all_inputs)
def test_grad_flow(values: Tensor) -> None:
    """Runs the shared gradient-flow assertion with a scalarizer shaped like the input."""

    scalarizer = _uw(tuple(values.shape))
    assert_grad_flow(scalarizer, values)


@mark.parametrize("values", all_inputs)
def test_grad_flows_to_log_var(values: Tensor) -> None:
    """Checks that backpropagation populates the log-variances with a finite gradient."""

    scalarizer = _uw(tuple(values.shape))
    loss = scalarizer(values)
    loss.backward()

    log_var_grad = scalarizer.log_var.grad
    assert log_var_grad is not None
    assert log_var_grad.isfinite().all()


@mark.parametrize(
    ["param_shape", "values_shape", "expectation"],
    [
        # Identical shapes are accepted, including the 0-dimensional case.
        ((5,), (5,), does_not_raise()),
        ((3, 4), (3, 4), does_not_raise()),
        ((), (), does_not_raise()),
        # Any mismatch is rejected: different size, extra dimension, or transposed shape.
        ((5,), (4,), raises(ValueError)),
        ((5,), (5, 1), raises(ValueError)),
        ((3, 4), (4, 3), raises(ValueError)),
    ],
)
def test_shape_check(
    param_shape: tuple[int, ...],
    values_shape: tuple[int, ...],
    expectation: ExceptionContext,
) -> None:
    """Checks that values are accepted only when their shape equals that of the log-variances."""

    values = ones_(values_shape)
    scalarizer = _uw(param_shape)
    with expectation:
        scalarizer(values)


def test_reset_restores_initial_log_var() -> None:
    """Checks that `reset` brings modified log-variances back to zero."""

    scalarizer = _uw((3,))

    # Move the log-variances away from their initial value without tracking gradients.
    with torch.no_grad():
        scalarizer.log_var.add_(1.0)

    scalarizer.reset()

    restored = scalarizer.log_var.detach()
    torch.testing.assert_close(restored, zeros_((3,)))


def test_does_not_raise_on_negative_input() -> None:
    """Checks that negative values are scalarized without error."""

    # Unlike GeometricMean, UW has no positivity precondition.
    scalarizer = _uw((3,))
    mixed_sign_values = tensor_([-1.0, -2.0, 3.0])
    assert_returns_scalar(scalarizer, mixed_sign_values)


def test_is_trainable() -> None:
    """Checks that one optimizer step moves the log-variances away from their initial zeros."""

    values = tensor_([2.0, 5.0])
    scalarizer = _uw((2,))
    optimizer = torch.optim.SGD(scalarizer.parameters(), lr=0.1)

    optimizer.zero_grad()
    loss = scalarizer(values)
    loss.backward()
    optimizer.step()

    updated_log_var = scalarizer.log_var.detach()
    assert not torch.equal(updated_log_var, zeros_((2,)))


def test_representations() -> None:
    """Checks `repr` for an int shape and a tuple shape, and `str` for an int shape."""

    expected_reprs = {
        3: "UW(shape=(3,))",
        (2, 3): "UW(shape=(2, 3))",
    }
    for shape, expected in expected_reprs.items():
        assert repr(UW(shape)) == expected

    assert str(UW(3)) == "UW"
Loading