改了D返回scores, features

This commit is contained in:
bishe 2025-03-07 19:25:25 +08:00
parent 14ba81514f
commit 997fdd3770

View File

@ -1413,12 +1413,11 @@ class MLPDiscriminator(nn.Module):
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.linear1(x)
x = self.activation(x)
features = self.linear1(x) # 中间特征,即 D_real 或 D_fake
x = self.activation(features)
x = self.dropout(x)
x = self.linear2(x)
return self.dropout(x)
scores = self.linear2(x) # 最终分数,即 real_scores 或 fake_scores
return scores, features
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""