|
板凳
楼主 |
发表于 2023-7-14 14:34:43
|
只看该作者
模型的backbone是resnet50,后面接的GeM
这是一开始的GeM:
def gem(x, p=3, eps=1e-6):
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p)
class GeM(nn.Module):
def __init__(self, p=3, eps=1e-6):
super(GeM, self).__init__()
self.p = nn.Parameter(torch.ones(1) * p)
self.eps = eps
def forward(self, x):
return gem(x, p=self.p, eps=self.eps)[:,:,0,0]
def __repr__(self):
return (
self.__class__.__name__
+ f"(p={self.p.data.tolist()[0]:.4f}, eps={str(self.eps)})"
)
为了避免reciprocal这个算子改成了下面这样:
def gem(x, p=1/3, eps=1e-6):
return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(p)
class GeM(nn.Module):
def __init__(self, p=1/3, eps=1e-6):
super(GeM, self).__init__()
self.p = nn.Parameter(torch.ones(1) * p)
self.eps = eps
def forward(self, x):
return gem(x, p=self.p, eps=self.eps)[:,:,0,0]
def __repr__(self):
return (
self.__class__.__name__
+ f"(p={self.p.data.tolist()[0]:.4f}, eps={str(self.eps)})"
)
大佬帮忙看看是不是修改GeM出现了问题? |
|