diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py index a436e8f..890a1cd 100644 --- a/symbolic_nn_tests/experiment2/model.py +++ b/symbolic_nn_tests/experiment2/model.py @@ -4,9 +4,11 @@ from torch import nn class Model(nn.Module): - def __init__(self): + def __init__(self, return_module_y=False): super().__init__() + self.return_module_y = return_module_y + self.x0_encoder = nn.TransformerEncoderLayer(7, 7) self.x1_encoder = nn.TransformerEncoderLayer(10, 10) self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder) @@ -36,11 +38,14 @@ class Model(nn.Module): def forward(self, x): x0, x1 = x - x0 = self.encode_x0(x0) - x1 = self.encode_x1(x1) - x = torch.cat([x0, x1], dim=1) + y0 = self.encode_x0(x0) + y1 = self.encode_x1(x1) + x = torch.cat([y0, y1], dim=1) y = self.ff(x) - return y + if self.return_module_y: + return y, y0, y1 + else: + return y # This is just a quick, lazy way to ensure all models are trained on the same dataset