Mirror of https://github.com/Cian-H/symbolic_nn_tests.git, synced 2025-12-22 22:22:01 +00:00
Added ability to train on submodules in expt2
@@ -4,9 +4,11 @@ from torch import nn
 
 
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, return_module_y=False):
         super().__init__()
 
+        self.return_module_y = return_module_y
+
         self.x0_encoder = nn.TransformerEncoderLayer(7, 7)
         self.x1_encoder = nn.TransformerEncoderLayer(10, 10)
         self.encode_x0 = self.create_xval_encoding_fn(self.x0_encoder)
@@ -36,11 +38,14 @@ class Model(nn.Module):
 
     def forward(self, x):
         x0, x1 = x
-        x0 = self.encode_x0(x0)
-        x1 = self.encode_x1(x1)
-        x = torch.cat([x0, x1], dim=1)
+        y0 = self.encode_x0(x0)
+        y1 = self.encode_x1(x1)
+        x = torch.cat([y0, y1], dim=1)
         y = self.ff(x)
-        return y
+        if self.return_module_y:
+            return y, y0, y1
+        else:
+            return y
 
 
 # This is just a quick, lazy way to ensure all models are trained on the same dataset
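
For reference, a minimal self-contained sketch of the pattern this commit introduces: a two-branch model whose return_module_y flag also returns the per-branch encodings, so losses can be attached to the submodules. TwoBranchModel, the nn.Linear stand-ins, and the batch shapes below are illustrative assumptions, not the repository's actual Model (which uses nn.TransformerEncoderLayer encoders and helpers such as create_xval_encoding_fn defined elsewhere in expt2).

import torch
from torch import nn


class TwoBranchModel(nn.Module):
    """Illustrative stand-in for the expt2 Model: two encoders plus a head,
    with an optional flag that also returns the per-branch outputs."""

    def __init__(self, return_module_y=False):
        super().__init__()
        self.return_module_y = return_module_y
        self.x0_encoder = nn.Linear(7, 7)    # placeholder for TransformerEncoderLayer(7, 7)
        self.x1_encoder = nn.Linear(10, 10)  # placeholder for TransformerEncoderLayer(10, 10)
        self.ff = nn.Linear(17, 1)           # placeholder head; the real ff is defined elsewhere

    def forward(self, x):
        x0, x1 = x
        y0 = self.x0_encoder(x0)
        y1 = self.x1_encoder(x1)
        y = self.ff(torch.cat([y0, y1], dim=1))
        if self.return_module_y:
            return y, y0, y1  # expose submodule outputs for per-module training signals
        return y


x0, x1 = torch.randn(4, 7), torch.randn(4, 10)  # placeholder batches with feature dims 7 and 10
y, y0, y1 = TwoBranchModel(return_module_y=True)((x0, x1))
y_only = TwoBranchModel()((x0, x1))  # the default flag keeps the old single-output behaviour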