Added ability to train on submodules in expt2

This commit is contained in:
2024-05-30 17:34:53 +01:00
parent 52fb30a9ed
commit d2ec5c0c1a

View File

@@ -4,9 +4,11 @@ from torch import nn
class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self, return_module_y=False):
super().__init__() super().__init__()
self.return_module_y = return_module_y
self.x0_encoder = nn.TransformerEncoderLayer(7, 7) self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
self.x1_encoder = nn.TransformerEncoderLayer(10, 10) self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder) self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
@@ -36,11 +38,14 @@ class Model(nn.Module):
def forward(self, x): def forward(self, x):
x0, x1 = x x0, x1 = x
x0 = self.encode_x0(x0) y0 = self.encode_x0(x0)
x1 = self.encode_x1(x1) y1 = self.encode_x1(x1)
x = torch.cat([x0, x1], dim=1) x = torch.cat([y0, y1], dim=1)
y = self.ff(x) y = self.ff(x)
return y if self.return_module_y:
return y, y0, y1
else:
return y
# This is just a quick, lazy way to ensure all models are trained on the same dataset # This is just a quick, lazy way to ensure all models are trained on the same dataset