From 674175b260c955c29c6a6e857a847259221f6be0 Mon Sep 17 00:00:00 2001 From: Cian-H Date: Mon, 13 May 2024 16:05:26 +0100 Subject: [PATCH] Added trainer and made dataset hot-swappable --- poetry.lock | 136 +++++++++++++++++++++++++++++++- pyproject.toml | 1 + symbolic_nn_tests/dataloader.py | 3 +- symbolic_nn_tests/trainer.py | 19 +++++ 4 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 symbolic_nn_tests/trainer.py diff --git a/poetry.lock b/poetry.lock index 855997f..9c8c5d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + [[package]] name = "aenum" version = "3.1.12" @@ -1097,6 +1108,64 @@ files = [ docs = ["Sphinx", "furo"] test = ["objgraph", "psutil"] +[[package]] +name = "grpcio" +version = "1.63.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, + {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, + {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, + {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, + {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, + {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, + {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, + {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, + {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, + {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, + {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, + {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, + {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, + {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, + {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, + {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, + {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, + {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, + {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, + {file = "grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, + {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, + {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, + {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, + {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, + {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, + {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, + {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, + {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, + {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, + {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, + {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.63.0)"] + [[package]] name = "idna" version = "3.7" @@ -1585,6 +1654,21 @@ babel = ["Babel"] lingua = ["lingua"] testing = ["pytest"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "2.1.0" @@ -3300,6 +3384,39 @@ files = [ {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, ] +[[package]] +name = "tensorboard" +version = "2.16.2" +description = "TensorBoard lets you watch Tensors Flow" +optional = false +python-versions = ">=3.9" +files = [ + {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, +] + +[package.dependencies] +absl-py = ">=0.4" +grpcio = ">=1.48.2" +markdown = ">=2.6.8" +numpy = ">=1.12.0" +protobuf = ">=3.19.6,<4.24.0 || >4.24.0" +setuptools = ">=41.0.0" +six = ">1.9" +tensorboard-data-server = ">=0.7.0,<0.8.0" +werkzeug = ">=1.0.1" + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +description = "Fast data loading for TensorBoard" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, + {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, + {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, +] + [[package]] name = "timg" version = "1.1.6" @@ -3625,6 +3742,23 @@ files = [ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] +[[package]] +name = "werkzeug" +version = "3.0.3" +description = "The comprehensive WSGI web application library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, +] + +[package.dependencies] +MarkupSafe = ">=2.1.1" + +[package.extras] +watchdog = ["watchdog (>=2.3)"] + [[package]] name = "yarl" version = "1.9.4" @@ -3731,4 +3865,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b2be2e46a3aed8b56d958ea86ba3973c1509f069c3f2d82df09f418e1704a708" +content-hash = "035f9d736293c9021a1dcdf67fd794bcc9034180dd2f458d9132e732daca06d9" diff --git a/pyproject.toml b/pyproject.toml index b38e79c..7f4e178 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ ipython = "^8.24.0" matplotlib-backend-kitty = "^2.1.2" euporie = "^2.8.2" ipykernel = "^6.29.4" +tensorboard = "^2.16.2" [build-system] diff --git a/symbolic_nn_tests/dataloader.py b/symbolic_nn_tests/dataloader.py index 1f2ce0f..3abbc36 100644 --- a/symbolic_nn_tests/dataloader.py +++ b/symbolic_nn_tests/dataloader.py @@ -10,13 +10,14 @@ PROJECT_ROOT = Path(__file__).parent.parent def get_dataset( split: (float, float, float) = (0.7, 0.1, 0.2), + dataset=Caltech256, batch_size: int = 128, drop_last: bool = False, **kwargs, ): _kwargs = {"transform": ToTensor()} _kwargs.update(kwargs) - ds = Caltech256(PROJECT_ROOT / "datasets/", download=True, **_kwargs) + ds = dataset(PROJECT_ROOT / "datasets/", download=True, **_kwargs) train, test, val = ( BatchSampler(i, batch_size, drop_last) for i in random_split(ds, split) ) diff --git a/symbolic_nn_tests/trainer.py b/symbolic_nn_tests/trainer.py new file mode 100644 index 0000000..3bc781b --- /dev/null +++ b/symbolic_nn_tests/trainer.py @@ -0,0 +1,19 @@ +from torch import nn, optim +import lightning as L + + +class Trainer(L.LighningModule): + def __init__(self, model): + super().__init__() + self.model = model + + def training_step(self, batch, batch_idx): + x, y = batch + y_pred = self.model(x) + loss = nn.functional.mse_loss(y_pred, y) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=1e-3) + return optimizer