From d2ec5c0c1a1ac85b4bbb8f03b04c281779af998c Mon Sep 17 00:00:00 2001
From: Cian Hughes
Date: Thu, 30 May 2024 17:34:53 +0100
Subject: [PATCH] Added ability to train on submodules in expt2

---
 symbolic_nn_tests/experiment2/model.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py
index a436e8f..890a1cd 100644
--- a/symbolic_nn_tests/experiment2/model.py
+++ b/symbolic_nn_tests/experiment2/model.py
@@ -4,9 +4,11 @@ from torch import nn
 
 
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, return_module_y=False):
         super().__init__()
 
+        self.return_module_y = return_module_y
+
         self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
         self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
@@ -36,11 +38,14 @@ class Model(nn.Module):
 
     def forward(self, x):
         x0, x1 = x
-        x0 = self.encode_x0(x0)
-        x1 = self.encode_x1(x1)
-        x = torch.cat([x0, x1], dim=1)
+        y0 = self.encode_x0(x0)
+        y1 = self.encode_x1(x1)
+        x = torch.cat([y0, y1], dim=1)
         y = self.ff(x)
-        return y
+        if self.return_module_y:
+            return y, y0, y1
+        else:
+            return y
 
 
 # This is just a quick, lazy way to ensure all models are trained on the same dataset
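
Note: below is a minimal, runnable sketch of how the new return_module_y flag might be consumed in a training step. It uses a stand-in two-branch module, not the repo's actual Model (whose encoding helpers and tensor shapes aren't shown in this patch), and the targets and 0.1 auxiliary weighting are illustrative assumptions, not part of the change.

    import torch
    from torch import nn

    # Stand-in for the patched Model: when return_module_y is set, forward()
    # also exposes the per-branch outputs y0 and y1, so the submodules can
    # receive their own supervision alongside the main objective.
    class TwoBranch(nn.Module):
        def __init__(self, return_module_y=False):
            super().__init__()
            self.return_module_y = return_module_y
            self.enc0 = nn.Linear(7, 4)   # stand-ins for the transformer branches
            self.enc1 = nn.Linear(10, 4)
            self.ff = nn.Linear(8, 1)

        def forward(self, x):
            x0, x1 = x
            y0 = self.enc0(x0)
            y1 = self.enc1(x1)
            y = self.ff(torch.cat([y0, y1], dim=1))
            if self.return_module_y:
                return y, y0, y1
            return y

    model = TwoBranch(return_module_y=True)
    y, y0, y1 = model((torch.rand(8, 7), torch.rand(8, 10)))

    # Hypothetical combined objective: auxiliary terms on the submodule
    # outputs plus the main head (zero targets and 0.1 weight are made up).
    loss = y.pow(2).mean() + 0.1 * (y0.pow(2).mean() + y1.pow(2).mean())
    loss.backward()

With return_module_y=False (the default), the forward pass returns only y, so existing training code that expects a single tensor keeps working unchanged.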