From 5465118b8c294a2b99d44ffd5c93b11564725739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Fri, 8 Mar 2024 10:23:28 +0800 Subject: [PATCH] [Fix] update build loss api (#3587) ## Motivation Use `MODELS.build` instead of `build_loss` ## Modification Please briefly describe what modification is made in this PR. --- mmseg/models/decode_heads/decode_head.py | 6 +++--- mmseg/models/decode_heads/enc_head.py | 3 +-- mmseg/models/decode_heads/vpd_depth_head.py | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 179d871fd1..fd53afe22d 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -8,9 +8,9 @@ from mmengine.model import BaseModule from torch import Tensor +from mmseg.registry import MODELS from mmseg.structures import build_pixel_sampler from mmseg.utils import ConfigType, SampleList -from ..builder import build_loss from ..losses import accuracy from ..utils import resize @@ -140,11 +140,11 @@ def __init__(self, self.threshold = threshold if isinstance(loss_decode, dict): - self.loss_decode = build_loss(loss_decode) + self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: - self.loss_decode.append(build_loss(loss)) + self.loss_decode.append(MODELS.build(loss)) else: raise TypeError(f'loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}') diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index ef48fb6995..2bba73b301 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -9,7 +9,6 @@ from mmseg.registry import MODELS from mmseg.utils import ConfigType, SampleList -from ..builder import build_loss from ..utils import Encoding, resize from .decode_head import BaseDecodeHead @@ -128,7 +127,7 @@ def __init__(self, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) if self.use_se_loss: - self.loss_se_decode = build_loss(loss_se_decode) + self.loss_se_decode = MODELS.build(loss_se_decode) self.se_layer = nn.Linear(self.channels, self.num_classes) def forward(self, inputs): diff --git a/mmseg/models/decode_heads/vpd_depth_head.py b/mmseg/models/decode_heads/vpd_depth_head.py index 0c54c2da1b..65bdfbd8d9 100644 --- a/mmseg/models/decode_heads/vpd_depth_head.py +++ b/mmseg/models/decode_heads/vpd_depth_head.py @@ -10,7 +10,6 @@ from mmseg.registry import MODELS from mmseg.utils import SampleList -from ..builder import build_loss from ..utils import resize from .decode_head import BaseDecodeHead @@ -184,11 +183,11 @@ def __init__( # build loss if isinstance(loss_decode, dict): - self.loss_decode = build_loss(loss_decode) + self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: - self.loss_decode.append(build_loss(loss)) + self.loss_decode.append(MODELS.build(loss)) else: raise TypeError(f'loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}')