diff --git a/poetry.lock b/poetry.lock index 4fcce23..43c7075 100644 --- a/poetry.lock +++ b/poetry.lock @@ -633,6 +633,20 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "cheap-repr" +version = "0.5.1" +description = "Better version of repr/reprlib for short, cheap string representations." +optional = false +python-versions = "*" +files = [ + {file = "cheap_repr-0.5.1-py2.py3-none-any.whl", hash = "sha256:30096998aeb49367a4a153988d7a99dce9dc59bbdd4b19740da6b4f3f97cf2ff"}, + {file = "cheap_repr-0.5.1.tar.gz", hash = "sha256:31ec63b9d8394aa23d746c8376c8307f75f9fca0b983566b8bcf13cc661fe6dd"}, +] + +[package.extras] +tests = ["Django", "Django (<2)", "Django (<3)", "chainmap", "numpy (>=1.16.3)", "numpy (>=1.16.3,<1.17)", "numpy (>=1.16.3,<1.19)", "pandas (>=0.24.2)", "pandas (>=0.24.2,<0.25)", "pandas (>=0.24.2,<0.26)", "pytest"] + [[package]] name = "click" version = "8.1.7" @@ -1745,6 +1759,17 @@ files = [ [package.dependencies] ansicon = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "json5" version = "0.9.25" @@ -3581,6 +3606,23 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyaml" +version = "24.4.0" +description = "PyYAML-based module to produce a bit more pretty and readable YAML-serialized data" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyaml-24.4.0-py3-none-any.whl", hash = "sha256:acc2b39c55cb0cbe4f694a6d3886f89ad3d2a5b3efcece526202f8de9a6b27de"}, + {file = "pyaml-24.4.0.tar.gz", hash = "sha256:0e483d9289010e747a325dc43171bcc39d6562dd1dd4719e8cc7e7c96c99fce6"}, +] + +[package.dependencies] +PyYAML = "*" + +[package.extras] +anchors = ["unidecode"] + [[package]] name = "pyarrow" version = "16.1.0" @@ -4442,6 +4484,117 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.5.0" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12e40ac48555e6b551f0a0a5743cc94cc5a765c9513fe708e01f0aa001da2801"}, + {file = "scikit_learn-1.5.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f405c4dae288f5f6553b10c4ac9ea7754d5180ec11e296464adb5d6ac68b6ef5"}, + {file = "scikit_learn-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df8ccabbf583315f13160a4bb06037bde99ea7d8211a69787a6b7c5d4ebb6fc3"}, + {file = "scikit_learn-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c75ea812cd83b1385bbfa94ae971f0d80adb338a9523f6bbcb5e0b0381151d4"}, + {file = "scikit_learn-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:a90c5da84829a0b9b4bf00daf62754b2be741e66b5946911f5bdfaa869fcedd6"}, + {file = "scikit_learn-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a65af2d8a6cce4e163a7951a4cfbfa7fceb2d5c013a4b593686c7f16445cf9d"}, + {file = "scikit_learn-1.5.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:4c0c56c3005f2ec1db3787aeaabefa96256580678cec783986836fc64f8ff622"}, + {file = "scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f77547165c00625551e5c250cefa3f03f2fc92c5e18668abd90bfc4be2e0bff"}, + {file = "scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:118a8d229a41158c9f90093e46b3737120a165181a1b58c03461447aa4657415"}, + {file = "scikit_learn-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:a03b09f9f7f09ffe8c5efffe2e9de1196c696d811be6798ad5eddf323c6f4d40"}, + {file = "scikit_learn-1.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:460806030c666addee1f074788b3978329a5bfdc9b7d63e7aad3f6d45c67a210"}, + {file = "scikit_learn-1.5.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:1b94d6440603752b27842eda97f6395f570941857456c606eb1d638efdb38184"}, + {file = "scikit_learn-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d82c2e573f0f2f2f0be897e7a31fcf4e73869247738ab8c3ce7245549af58ab8"}, + {file = "scikit_learn-1.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3a10e1d9e834e84d05e468ec501a356226338778769317ee0b84043c0d8fb06"}, + {file = "scikit_learn-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:855fc5fa8ed9e4f08291203af3d3e5fbdc4737bd617a371559aaa2088166046e"}, + {file = "scikit_learn-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:40fb7d4a9a2db07e6e0cae4dc7bdbb8fada17043bac24104d8165e10e4cff1a2"}, + {file = "scikit_learn-1.5.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:47132440050b1c5beb95f8ba0b2402bbd9057ce96ec0ba86f2f445dd4f34df67"}, + {file = "scikit_learn-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:174beb56e3e881c90424e21f576fa69c4ffcf5174632a79ab4461c4c960315ac"}, + {file = "scikit_learn-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261fe334ca48f09ed64b8fae13f9b46cc43ac5f580c4a605cbb0a517456c8f71"}, + {file = "scikit_learn-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:057b991ac64b3e75c9c04b5f9395eaf19a6179244c089afdebaad98264bff37c"}, + {file = "scikit_learn-1.5.0.tar.gz", hash = "sha256:789e3db01c750ed6d496fa2db7d50637857b451e57bcae863bff707c1247bef7"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.15.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.15.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scikit-optimize" +version = "0.10.2" +description = "Sequential model-based optimization toolbox." +optional = false +python-versions = "*" +files = [ + {file = "scikit_optimize-0.10.2-py2.py3-none-any.whl", hash = "sha256:45bc7e879b086133984721f2f6735a86c085073f6c481c2ec665b5c67b44d723"}, + {file = "scikit_optimize-0.10.2.tar.gz", hash = "sha256:00a3d91bf9015e292b6e7aaefe7e6cb95e8d25ce19adafd2cd88849e1a0b0da0"}, +] + +[package.dependencies] +joblib = ">=0.11" +numpy = ">=1.20.3" +packaging = ">=21.3" +pyaml = ">=16.9" +scikit-learn = ">=1.0.0" +scipy = ">=1.1.0" + +[package.extras] +dev = ["flake8", "pandas", "pytest", "pytest-cov", "pytest-xdist"] +doc = ["memory-profiler", "numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-gallery (>=0.6)"] +plots = ["matplotlib (>=2.0.0)"] + +[[package]] +name = "scipy" +version = "1.13.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"}, + {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"}, + {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"}, + {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"}, + {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"}, + {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"}, + {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"}, + {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"}, + {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"}, + {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"}, + {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"}, + {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"}, + {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"}, + {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"}, + {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"}, + {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"}, + {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"}, + {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"}, +] + +[package.dependencies] +numpy = ">=1.22.4,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "send2trash" version = "1.8.3" @@ -4679,6 +4832,27 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "snoop" +version = "0.4.3" +description = "Powerful debugging tools for Python" +optional = false +python-versions = "*" +files = [ + {file = "snoop-0.4.3-py2.py3-none-any.whl", hash = "sha256:b7418581889ff78b29d9dc5ad4625c4c475c74755fb5cba82c693c6e32afadc0"}, + {file = "snoop-0.4.3.tar.gz", hash = "sha256:2e0930bb19ff0dbdaa6f5933f88e89ed5984210ea9f9de0e1d8231fa5c1c1f25"}, +] + +[package.dependencies] +asttokens = "*" +cheap-repr = ">=0.4.0" +executing = "*" +pygments = "*" +six = "*" + +[package.extras] +tests = ["Django", "birdseye", "littleutils", "numpy (>=1.16.5)", "pandas (>=0.24.2)", "pprintpp", "prettyprinter", "pytest", "pytest-order", "pytest-order (<=0.11.0)"] + [[package]] name = "soupsieve" version = "2.5" @@ -4903,6 +5077,17 @@ files = [ {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, ] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "timg" version = "1.1.6" @@ -5515,4 +5700,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "26f68613584a0e1f2a3a8d501e0eaa39fc18197621d8a595bfac6fb4d4455a27" +content-hash = "3f561f238f435bd9e6432fdaec496c83e0321c6fc2b6a1effa223ded6f464700" diff --git a/pyproject.toml b/pyproject.toml index e5966e7..7fbd951 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ hvplot = "^0.10.0" pyarrow = "^16.1.0" loguru = "^0.7.2" plotly = "^5.22.0" +snoop = "^0.4.3" +scikit-optimize = "^0.10.2" [build-system] diff --git a/symbolic_nn_tests/experiment2/__init__.py b/symbolic_nn_tests/experiment2/__init__.py index fda05b1..9bf85df 100644 --- a/symbolic_nn_tests/experiment2/__init__.py +++ b/symbolic_nn_tests/experiment2/__init__.py @@ -1,7 +1,15 @@ LEARNING_RATE = 10e-5 -def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True): +def test( + train_loss, + val_loss, + test_loss, + version, + tensorboard=True, + wandb=True, + semantic_trainer=False, +): from .model import main as test_model logger = [] @@ -37,6 +45,7 @@ def test(train_loss, val_loss, test_loss, version, tensorboard=True, wandb=True) val_loss=val_loss, test_loss=test_loss, lr=LEARNING_RATE, + semantic_trainer=semantic_trainer, ) if wandb: @@ -70,10 +79,13 @@ def run(tensorboard: bool = True, wandb: bool = True): wandb_logger = wandb test( - train_loss=semantic_loss.PositiveSlopeLinearLoss(wandb_logger, version), + train_loss=semantic_loss.PositiveSlopeLinearLoss( + wandb_logger, version, log_freq=50 + ), val_loss=unpacking_smooth_l1_loss, test_loss=unpacking_smooth_l1_loss, version=version, tensorboard=tensorboard, wandb=wandb_logger, + semantic_trainer=True, ) diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py index 215119c..305c0ca 100644 --- a/symbolic_nn_tests/experiment2/model.py +++ b/symbolic_nn_tests/experiment2/model.py @@ -63,11 +63,15 @@ def main( val_loss=unpacking_smooth_l1_loss, test_loss=unpacking_smooth_l1_loss, logger=None, + semantic_trainer=False, **kwargs, ): import lightning as L - from symbolic_nn_tests.train import TrainingWrapper + if semantic_trainer: + from .train import TrainingWrapper + else: + from symbolic_nn_tests.train import TrainingWrapper if logger is None: from lightning.pytorch.loggers import TensorBoardLogger diff --git a/symbolic_nn_tests/experiment2/semantic_loss.py b/symbolic_nn_tests/experiment2/semantic_loss.py index b42f869..6e8369d 100644 --- a/symbolic_nn_tests/experiment2/semantic_loss.py +++ b/symbolic_nn_tests/experiment2/semantic_loss.py @@ -1,4 +1,5 @@ from symbolic_nn_tests.experiment2.math import linear_fit, linear_residuals, sech +from random import random from torch import nn import torch @@ -18,16 +19,17 @@ import torch # proportionality. -class PositiveSlopeLinearLoss: +class PositiveSlopeLinearLoss(nn.Module): def __init__(self, wandb_logger=None, version="", device="cuda", log_freq=50): - self.a = nn.Parameter(data=torch.randn(1), requires_grad=True).to(device) + super().__init__() + self.params = [random()] self.wandb_logger = wandb_logger self.version = version self.device = device self.log_freq = log_freq self.steps_since_log = 0 - def __call__(self, out, y): + def forward(self, out, y): x, y_pred = out x0, x1 = x @@ -62,10 +64,10 @@ class PositiveSlopeLinearLoss: # We also need to calculate a penalty that incentivizes a positive slope. For this, im using relu # to scale the slope as it will penalise negative slopes without just creating a reward hack for # maximizing slope. - slope_penalty = (nn.functional.relu(self.a * (-m)) + 1).mean() + slope_penalty = (nn.functional.relu(self.params[0] * (-m)) + 1).mean() - if self.wandb_logger and (self.steps_since_log >= 50): - self.wandb_logger.log_metrics({f"{self.version}-a": self.a}) + if self.wandb_logger and (self.steps_since_log >= self.log_freq): + self.wandb_logger.log_metrics({f"{self.version}-a": self.params}) self.steps_since_log = 0 else: self.steps_since_log += 1 diff --git a/symbolic_nn_tests/experiment2/train.py b/symbolic_nn_tests/experiment2/train.py index 77e3f74..2f5d150 100644 --- a/symbolic_nn_tests/experiment2/train.py +++ b/symbolic_nn_tests/experiment2/train.py @@ -1,20 +1,29 @@ -from symbolic_nn_tests.train import TrainingWrapper +from symbolic_nn_tests.train import TrainingWrapper as _TrainingWrapper +import torch +from skopt import Optimizer +from skopt.learning import RandomForestRegressor -class SemanticModuleTrainingWrapper(TrainingWrapper): - def __init__(self, model, *args, loss_func0, loss_func1, loss_agg, **kwargs): - assert len(args) == 0 +class TrainingWrapper(_TrainingWrapper): + def __init__(self, *args, loss_rate_target=-10, **kwargs): + super().__init__(*args, **kwargs) + self.loss_optimizer = Optimizer( + [(0.0, 1000.0)], + base_estimator=RandomForestRegressor( + n_jobs=-1, + ), + n_initial_points=10, + model_queue_size=10, + acq_func="gp_hedge", + ) + self.loss_rate_target = torch.tensor(loss_rate_target).float() + self.losses = [] - super().__init__(model, **kwargs) - self.loss_func0 = loss_func0 - self.loss_func1 = loss_func1 - self.loss_agg = loss_agg + def training_step(self, *args, **kwargs): + loss = super().training_step(*args, **kwargs) + self.adjust_train_loss(loss) + return loss - def _forward_step(self, batch, batch_idx, label=""): - x, y = batch - y_pred, y0, y1 = self.model(x) - loss = self.loss_func(y_pred, y) - loss0 = self.loss_func0(y0, x) - loss1 = self.loss_func1(y1, x) - self.log(f"{label}{'_' if label else ''}loss", loss) - return self.loss_agg(loss, loss0, loss1) + def adjust_train_loss(self, loss): + self.loss_optimizer.tell(self.train_loss.params, loss.item()) + self.train_loss.params = self.loss_optimizer.ask()