Skip to content

feat(aggregation): Add STCH#719

Merged
ValerianRey merged 5 commits into
SimplexLab:mainfrom
ppraneth:scalarization-3
Jun 6, 2026
Merged

feat(aggregation): Add STCH#719
ValerianRey merged 5 commits into
SimplexLab:mainfrom
ppraneth:scalarization-3

Conversation

@ppraneth
Copy link
Copy Markdown
Contributor

@ppraneth ppraneth commented May 30, 2026

New torchjd.scalarization.STCH, the smooth Tchebycheff scalarization from Smooth Tchebycheff Scalarization for Multi-Objective Optimization .

It returns a differentiable approximation of the weighted, shifted maximum of the values:

$$g_\mu^{\text{STCH}}(v \mid \lambda) = \mu \log \sum_{i=1}^{m} \exp\left(\frac{\lambda_i (f_i - z_i^*)}{\mu}\right)$$

where, following the paper's notation:

  • $f_i$ is the $i$-th input value (the $i$-th objective)
  • $m$ is the number of objectives (the number of elements of the input)
  • $\lambda_i$ is the preference weight for objective $i$ (the weights parameter)
  • $z_i^*$ is the $i$-th component of the ideal point (the reference parameter)
  • $\mu$ is the smoothing parameter (the mu parameter)

As $\mu \to 0$ this recovers the classical (non-differentiable) Tchebycheff $\max_i \lambda_i (f_i - z_i^*)$; larger $\mu$ gives a smoother approximation.

Design decisions

Confirmed with the maintainer before implementation:

  • Paper version, not the reference code. Both the official Xi-L/STCH impl and LibMTL's copy of it are stateful (epoch-based warmup + a running nadir estimate, applied to log(loss / nadir)), which diverges from eq. 9. We use the clean stateless formula the paper proves its theory on.
  • Name: STCH (the paper's acronym).
  • mu is required, no default. The paper tests $\mu \in {0.01, 0.1, 0.5, 1}$ and reports no single value is best across problems, so we force a conscious choice.
  • Full API: mu (required), weights (optional, default uniform on the simplex), reference (optional, default none).

One thing worth a look: the 1/m default and mu

The default weights is uniform on the simplex (1/m), matching the paper. A consequence: the exponent becomes $(f_i - z_i^*) / (m\mu)$, so the effective smoothing temperature is $m\mu$, not $\mu$. In practice this means the meaning of mu is coupled to the number of objectives — more objectives gives a smoother result for the same mu. This is faithful to the paper's simplex convention, but if you'd rather decouple mu from m, the alternative is an all-ones default. Happy to switch if you prefer.

Implementation

def forward(self, values: Tensor, /) -> Tensor:
    # shape checks for weights / reference omitted here
    weights = self.weights if self.weights is not None else torch.full_like(values, 1.0 / values.numel())
    shifted = values if self.reference is None else values - self.reference
    exponents = weights * shifted / self.mu
    return self.mu * torch.logsumexp(exponents.flatten(), dim=-1)
  • Uses torch.logsumexp, so it's numerically stable without manual max-subtraction.
  • No sign precondition on the input (unlike GeometricMean); negative values are fine.
  • mu <= 0 raises in __init__. weights / reference shape mismatches raise at call time (same pattern as Constant).
  • The simplex constraint on user-supplied weights is not enforced (permissive, consistent with Constant).

Files

File Purpose
src/torchjd/scalarization/_stch.py New class
src/torchjd/scalarization/__init__.py Export
docs/source/docs/scalarization/stch.rst Doc page
docs/source/docs/scalarization/index.rst Toctree entry
tests/unit/scalarization/test_stch.py Unit tests
CHANGELOG.md [Unreleased] entry

Test plan

  • uv run pytest tests/unit/scalarization/test_stch.py -W error -v
  • uv run pytest tests/unit -W error (full regression)
  • uv run ruff check && uv run ruff format --check
  • uv run ty check

ppraneth added 2 commits May 29, 2026 20:45
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from a team as a code owner May 30, 2026 02:41
@ppraneth ppraneth changed the title add STCH feat(scalarization): add STCH May 30, 2026
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels May 30, 2026
@github-actions github-actions Bot changed the title feat(scalarization): add STCH feat(aggregation): Add STCH May 30, 2026
@github-actions github-actions Bot changed the title feat(scalarization): add STCH feat(aggregation): Add STCH May 30, 2026
Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks for the PR, LGTM. We'll wait for @ValerianRey 's review still.

Comment thread src/torchjd/scalarization/_stch.py Outdated
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ValerianRey ValerianRey mentioned this pull request Jun 2, 2026
@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks a lot for the PR. It already looks really good.

I agree with the choice of going for what's described in the paper: the official implementation (and likewise the LibMTL implementation) seem to contain many important differences with respect to the paper, which make them not representative at all of what's described in the paper.

On the other hand, what's described in the paper seems to be reasonable, and it's also what's implemented in LibMoon (see the link I just added in the issue).

I would add that there are two extra things mentioned in the paper: appendices B1 and B2.

In B1, they describe a way to stabilize the method. This would maybe be a bit hard to implement for a user, out of STCH. So maybe we could add a boolean parameter stabilize to opt-in for that. It doesn't add state or anything, so it should be easy to do. @PierreQuinton @ppraneth what do you think?

Appendix B2 is about normalization of the function values, which can easily be handled by the user or by a torchjd.normalization that we'll add in the near future. So I think it shouldn't be added in this PR.

I'll make an in-depth review of the PR later.

@ppraneth
Copy link
Copy Markdown
Contributor Author

ppraneth commented Jun 5, 2026

Thanks a lot for the PR. It already looks really good.

I agree with the choice of going for what's described in the paper: the official implementation (and likewise the LibMTL implementation) seem to contain many important differences with respect to the paper, which make them not representative at all of what's described in the paper.

On the other hand, what's described in the paper seems to be reasonable, and it's also what's implemented in LibMoon (see the link I just added in the issue).

I would add that there are two extra things mentioned in the paper: appendices B1 and B2.

In B1, they describe a way to stabilize the method. This would maybe be a bit hard to implement for a user, out of STCH. So maybe we could add a boolean parameter stabilize to opt-in for that. It doesn't add state or anything, so it should be easy to do. @PierreQuinton @ppraneth what do you think?

Appendix B2 is about normalization of the function values, which can easily be handled by the user or by a torchjd.normalization that we'll add in the near future. So I think it shouldn't be added in this PR.

I'll make an in-depth review of the PR later.

Thanks for the review @ValerianRey ! I looked more carefully at B.1.

The stabilization there is specifically the max-subtraction trick applied before dividing by mu. Our current code does:

exponents = weights * shifted / self.mu
return self.mu * torch.logsumexp(exponents.flatten(), dim=-1)

torch.logsumexp does apply max-subtraction internally, so the exp inside it never overflows. But the division shifted / self.mu happens before logsumexp sees anything, so if the values are large and mu is small, that intermediate tensor can overflow to inf before the stabilization has a chance to help.

B.1's approach of centering y_i before dividing by mu avoids this entirely, since the values fed to exp are always non-positive regardless of input scale.

In float32 this only bites at extreme magnitudes, but in float16/bfloat16 mixed precision the /mu step can overflow at values of just a few hundred, so it's a real footgun in low precision. And since the fix is value-preserving (same output, same gradient) and essentially free, I'd rather just make it the default than gate it behind a flag:

y = weights * shifted
max_y = y.max()
exponents = (y - max_y) / self.mu
return self.mu * torch.logsumexp(exponents.flatten(), dim=-1) + max_y

This is what B.1 describes minus the dropped constant: it returns exactly the same STCH value and gradient as now, but never overflows in the /mu step. The + max_y cancels the centering in the gradient, so it stays correct without needing to detach anything.

Happy to update the PR with this if you agree it's the right default.

@PierreQuinton
Copy link
Copy Markdown
Contributor

PierreQuinton commented Jun 5, 2026

The result is the same? Or is it more than just numerical considerations?

If the former, then I would go for it as it looks more stable.

@ppraneth ppraneth requested a review from PierreQuinton June 5, 2026 16:19
@ppraneth
Copy link
Copy Markdown
Contributor Author

ppraneth commented Jun 5, 2026

@PierreQuinton Purely numerical. The value and the gradient are mathematically identical to what we have now, it just avoids the overflow in the /mu step. No behavior change otherwise.

@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks for the review @ValerianRey ! I looked more carefully at B.1.

The stabilization there is specifically the max-subtraction trick applied before dividing by mu. Our current code does:

exponents = weights * shifted / self.mu
return self.mu * torch.logsumexp(exponents.flatten(), dim=-1)

torch.logsumexp does apply max-subtraction internally, so the exp inside it never overflows. But the division shifted / self.mu happens before logsumexp sees anything, so if the values are large and mu is small, that intermediate tensor can overflow to inf before the stabilization has a chance to help.

B.1's approach of centering y_i before dividing by mu avoids this entirely, since the values fed to exp are always non-positive regardless of input scale.

In float32 this only bites at extreme magnitudes, but in float16/bfloat16 mixed precision the /mu step can overflow at values of just a few hundred, so it's a real footgun in low precision. And since the fix is value-preserving (same output, same gradient) and essentially free, I'd rather just make it the default than gate it behind a flag:

y = weights * shifted
max_y = y.max()
exponents = (y - max_y) / self.mu
return self.mu * torch.logsumexp(exponents.flatten(), dim=-1) + max_y

This is what B.1 describes minus the dropped constant: it returns exactly the same STCH value and gradient as now, but never overflows in the /mu step. The + max_y cancels the centering in the gradient, so it stays correct without needing to detach anything.

Happy to update the PR with this if you agree it's the right default.

Thanks for explaining. I understand better now, and I agree with your suggested change. Please go ahead with it.

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just did the thorough review, and I have nothing to report, this is super clean.

We can merge after you add the extra stabilization trick from appendix B1.

Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth
Copy link
Copy Markdown
Contributor Author

ppraneth commented Jun 6, 2026

@ValerianRey I have made the changes

@ValerianRey ValerianRey merged commit f095b0f into SimplexLab:main Jun 6, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants