Added ability to train on submodules in expt2

This commit is contained in:
2024-05-30 17:34:53 +01:00
parent 52fb30a9ed
commit d2ec5c0c1a

View File

@@ -4,9 +4,11 @@ from torch import nn
class Model(nn.Module):
def __init__(self):
def __init__(self, return_module_y=False):
super().__init__()
self.return_module_y = return_module_y
self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
@@ -36,10 +38,13 @@ class Model(nn.Module):
def forward(self, x):
x0, x1 = x
x0 = self.encode_x0(x0)
x1 = self.encode_x1(x1)
x = torch.cat([x0, x1], dim=1)
y0 = self.encode_x0(x0)
y1 = self.encode_x1(x1)
x = torch.cat([y0, y1], dim=1)
y = self.ff(x)
if self.return_module_y:
return y, y0, y1
else:
return y