diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index ea5e856..95fd4b1 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -3,3 +3,4 @@ ================ Training Loss (Sun Feb 23 16:00:07 2025) ================ ================ Training Loss (Sun Feb 23 16:02:40 2025) ================ ================ Training Loss (Sun Feb 23 16:05:19 2025) ================ +================ Training Loss (Sun Feb 23 16:06:44 2025) ================ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 340cd49..eb910dd 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index d56f29a..0000f0f 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -406,8 +406,8 @@ class RomaUnsbModel(BaseModel): self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True) print(self.mutil_real_A0_tokens) - self.mutil_real_A0_tokens = torch.tensor(self.mutil_real_A0_tokens, device=self.device) - self.mutil_real_A1_tokens = torch.tensor(self.mutil_real_A1_tokens, device=self.device) + self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=1).to(self.device) + self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=1).to(self.device) # 执行一次SB模块