diff --git a/.gitignore b/.gitignore index e037d8f..d9c2d02 100644 --- a/.gitignore +++ b/.gitignore @@ -164,6 +164,7 @@ cython_debug/ .jj scratchpad.ipynb datasets/ +explore/ lightning_logs/ logs/ wandb/ diff --git a/.mailmap b/.mailmap new file mode 100644 index 0000000..8b2a52b --- /dev/null +++ b/.mailmap @@ -0,0 +1 @@ +Cian Hughes diff --git a/poetry.lock b/poetry.lock index bde4d1c..289607b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -412,6 +412,28 @@ jinxed = {version = ">=1.1.0", markers = "platform_system == \"Windows\""} six = ">=1.9.0" wcwidth = ">=0.1.4" +[[package]] +name = "bokeh" +version = "3.4.1" +description = "Interactive plots and applications in the browser from Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "bokeh-3.4.1-py3-none-any.whl", hash = "sha256:1e3c502a0a8205338fc74dadbfa321f8a0965441b39501e36796a47b4017b642"}, + {file = "bokeh-3.4.1.tar.gz", hash = "sha256:d824961e4265367b0750ce58b07e564ad0b83ca64b335521cd3421e9b9f10d89"}, +] + +[package.dependencies] +contourpy = ">=1.2" +Jinja2 = ">=2.9" +numpy = ">=1.16" +packaging = ">=16.8" +pandas = ">=1.2" +pillow = ">=7.1.0" +PyYAML = ">=3.10" +tornado = ">=6.2" +xyzservices = ">=2021.09.1" + [[package]] name = "bpython" version = "0.24" @@ -636,6 +658,25 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "colorcet" +version = "3.1.0" +description = "Collection of perceptually uniform colormaps" +optional = false +python-versions = ">=3.7" +files = [ + {file = "colorcet-3.1.0-py3-none-any.whl", hash = "sha256:2a7d59cc8d0f7938eeedd08aad3152b5319b4ba3bcb7a612398cc17a384cb296"}, + {file = "colorcet-3.1.0.tar.gz", hash = "sha256:2921b3cd81a2288aaf2d63dbc0ce3c26dcd882e8c389cc505d6886bf7aa9a4eb"}, +] + +[package.extras] +all = ["colorcet[doc]", "colorcet[examples]", "colorcet[tests-extra]", "colorcet[tests]"] +doc = ["colorcet[examples]", "nbsite (>=0.8.4)", "sphinx-copybutton"] +examples = ["bokeh", "holoviews", "matplotlib", "numpy"] +tests = ["packaging", "pre-commit", "pytest (>=2.8.5)", "pytest-cov"] +tests-examples = ["colorcet[examples]", "nbval"] +tests-extra = ["colorcet[tests]", "pytest-mpl"] + [[package]] name = "colorlog" version = "6.8.2" @@ -1392,6 +1433,43 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "holoviews" +version = "1.18.3" +description = "Stop plotting your data - annotate your data and let it visualize itself." +optional = false +python-versions = ">=3.9" +files = [ + {file = "holoviews-1.18.3-py2.py3-none-any.whl", hash = "sha256:b94b96560b64a84c07e89115aaf9b226e6009684800ec84d3c88cbad122c0c46"}, + {file = "holoviews-1.18.3.tar.gz", hash = "sha256:578e30e89d72754f97a83ebe08198fec8e87cc7e49b25b9f31ec393f939ca500"}, +] + +[package.dependencies] +colorcet = "*" +numpy = ">=1.0" +packaging = "*" +pandas = ">=0.20.0" +panel = ">=1.0" +param = ">=1.12.0,<3.0" +pyviz-comms = ">=0.7.4" + +[package.extras] +all = ["bokeh (>=3.1)", "cftime", "codecov", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbconvert", "nbsite (>=0.8.4,<0.9.0)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "playwright", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-github-actions-annotate-failures", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"] +build = ["param (>=1.7.0)", "pyct (>=0.4.4)", "setuptools (>=30.3.0)"] +doc = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "graphviz", "ipython (>=5.4.0)", "matplotlib (>=3)", "myst-nb (<1)", "nbsite (>=0.8.4,<0.9.0)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "selenium", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"] +examples = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"] +examples-tests = ["bokeh (>=3.1)", "cftime", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbval", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pyarrow", "scikit-image", "scipy", "shapely", "streamz (>=0.5.0)", "xarray (>=0.10.4)"] +lint = ["pre-commit", "ruff"] +notebook = ["ipython (>=5.4.0)", "notebook"] +recommended = ["bokeh (>=3.1)", "ipython (>=5.4.0)", "matplotlib (>=3)", "notebook"] +tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"] +tests-ci = ["codecov", "pytest-github-actions-annotate-failures"] +tests-core = ["bokeh (>=3.1)", "contourpy", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist"] +tests-gpu = ["bokeh (>=3.1)", "cftime", "contourpy", "cudf", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "networkx", "pillow", "plotly (>=4.0)", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "xarray (>=0.10.4)"] +tests-nb = ["nbval"] +ui = ["playwright", "pytest-playwright"] +unit-tests = ["bokeh (>=3.1)", "cftime", "contourpy", "dash (>=1.16)", "dask", "datashader (>=0.11.1)", "ffmpeg", "ibis-framework", "ipython (>=5.4.0)", "matplotlib (>=3)", "nbconvert", "netcdf4", "networkx", "notebook", "notebook (>=7.0)", "pillow", "plotly (>=4.0)", "pooch", "pre-commit", "pyarrow", "pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "ruff", "scikit-image", "scipy", "scipy (>=1.10)", "selenium", "shapely", "spatialpandas", "streamz (>=0.5.0)", "xarray (>=0.10.4)"] + [[package]] name = "httpcore" version = "1.0.5" @@ -1437,6 +1515,41 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +[[package]] +name = "hvplot" +version = "0.10.0" +description = "A high-level plotting API for the PyData ecosystem built on HoloViews." +optional = false +python-versions = ">=3.8" +files = [ + {file = "hvplot-0.10.0-py3-none-any.whl", hash = "sha256:fe90ccb48163a6a62ae5bd6b008c2cb15cbf5b276f6ad6839ef5470b1c480d16"}, + {file = "hvplot-0.10.0.tar.gz", hash = "sha256:e87486a95bfe151ab52ef163a5e93d9cbd043992cf0b755ccadd2bf36fedd376"}, +] + +[package.dependencies] +bokeh = ">=1.0.0" +colorcet = ">=2" +holoviews = ">=1.11.0" +numpy = ">=1.15" +packaging = "*" +pandas = "*" +panel = ">=0.11.0" +param = ">=1.12.0,<3.0" + +[package.extras] +dev-extras = ["setuptools-scm (>=6)"] +doc = ["hvplot[examples]", "nbsite (>=0.8.4)", "sphinxext-rediraffe"] +examples = ["dask[dataframe] (>=2021.3.0)", "datashader (>=0.6.5)", "fugue[sql]", "geodatasets (>=2023.12.0)", "hvplot[fugue-sql]", "ibis-framework[duckdb]", "intake (>=0.6.5,<2.0.0)", "intake-parquet (>=0.2.3)", "intake-xarray (>=0.5.0)", "ipywidgets", "matplotlib", "networkx (>=2.6.3)", "notebook (>=5.4)", "numba (>=0.51.0)", "pillow (>=8.2.0)", "plotly", "polars", "pooch (>=1.6.0)", "s3fs (>=2022.1.0)", "scikit-image (>=0.17.2)", "scipy (>=1.5.3)", "selenium (>=3.141.0)", "streamz (>=0.3.0)", "xarray (>=0.18.2)", "xyzservices (>=2022.9.0)"] +examples-tests = ["hvplot[examples]", "hvplot[tests-nb]"] +fugue-sql = ["fugue-sql-antlr (>=0.2.0)", "jinja2", "qpd (>=0.4.4)", "sqlglot"] +geo = ["cartopy", "fiona", "geopandas", "geoviews (>=1.9.0)", "pyproj", "rasterio", "rioxarray", "spatialpandas (>=0.4.3)"] +graphviz = ["pygraphviz"] +hvdev = ["colorcet (>=0.0.1a1)", "datashader (>=0.0.1a1)", "holoviews (>=0.0.1a1)", "panel (>=0.0.1a1)", "param (>=0.0.1a1)", "pyviz-comms (>=0.0.1a1)"] +hvdev-geo = ["geoviews (>=0.0.1a1)"] +tests = ["fugue[sql]", "hvplot[fugue-sql]", "hvplot[tests-core]", "ibis-framework[duckdb]", "polars"] +tests-core = ["dask[dataframe]", "ipywidgets", "matplotlib", "parameterized", "plotly", "pooch", "pre-commit", "pytest", "pytest-cov", "ruff", "scipy", "xarray"] +tests-nb = ["nbval", "pytest-xdist"] + [[package]] name = "idna" version = "3.7" @@ -3001,6 +3114,78 @@ files = [ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] +[[package]] +name = "pandas" +version = "2.2.2" +description = "Powerful data structures for data analysis, time series, and statistics" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, + {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, + {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, + {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, + {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, + {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, + {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, + {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, + {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, + {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, + {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, + {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, + {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, + {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, + {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, + {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, + {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, + {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, +] +python-dateutil = ">=2.8.2" +pytz = ">=2020.1" +tzdata = ">=2022.7" + +[package.extras] +all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] +aws = ["s3fs (>=2022.11.0)"] +clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] +compression = ["zstandard (>=0.19.0)"] +computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +consortium-standard = ["dataframe-api-compat (>=0.1.7)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] +feather = ["pyarrow (>=10.0.1)"] +fss = ["fsspec (>=2022.11.0)"] +gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] +hdf5 = ["tables (>=3.8.0)"] +html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] +mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] +parquet = ["pyarrow (>=10.0.1)"] +performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] +plot = ["matplotlib (>=3.6.3)"] +postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] +pyarrow = ["pyarrow (>=10.0.1)"] +spss = ["pyreadstat (>=1.2.0)"] +sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] +xml = ["lxml (>=4.9.2)"] + [[package]] name = "pandocfilters" version = "1.5.1" @@ -3012,6 +3197,64 @@ files = [ {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"}, ] +[[package]] +name = "panel" +version = "1.4.4" +description = "The powerful data exploration & web app framework for Python." +optional = false +python-versions = ">=3.9" +files = [ + {file = "panel-1.4.4-py3-none-any.whl", hash = "sha256:b49bb9676567b0c0730bf69348c057247080811aec56364dd4fcfba80e5e09a0"}, + {file = "panel-1.4.4.tar.gz", hash = "sha256:659e9fc5b495e6519c5d07e8148fa5eeed9bc648356ec83fc299381ba5a726ef"}, +] + +[package.dependencies] +bleach = "*" +bokeh = ">=3.4.0,<3.5.0" +linkify-it-py = "*" +markdown = "*" +markdown-it-py = "*" +mdit-py-plugins = "*" +pandas = ">=1.2" +param = ">=2.1.0,<3.0" +pyviz-comms = ">=2.0.0" +requests = "*" +tqdm = ">=4.48.0" +typing-extensions = "*" +xyzservices = ">=2021.09.1" + +[package.extras] +all = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"] +all-pip = ["aiohttp", "altair", "anywidget", "channels", "croniter", "dask-expr", "datashader", "diskcache", "django (<4)", "fastparquet", "flake8", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipython (>=7.0)", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "jupyter-server", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "nbval", "networkx (>=2.5)", "numba (<0.58)", "numpy", "pandas (<2.1.0)", "pandas (>=1.3)", "parameterized", "pillow", "playwright", "plotly", "plotly (>=4.0)", "pre-commit", "psutil", "pydeck", "pyinstrument (>=4.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-playwright", "pytest-rerunfailures", "pytest-xdist", "pyvista", "reacton", "scikit-image", "scikit-learn", "scipy", "seaborn", "streamz", "textual", "tomli", "twine", "vega-datasets", "vtk", "watchfiles", "xarray", "xgboost"] +build = ["bleach", "bokeh (>=3.4.0,<3.5.0)", "cryptography (<39)", "markdown", "packaging", "param (>=2.0.0)", "pyviz-comms (>=2.0.0)", "requests", "setuptools (>=42)", "tqdm (>=4.48.0)", "urllib3 (<2.0)"] +doc = ["holoviews (>=1.16.0)", "jupyterlab", "lxml", "matplotlib", "nbsite (>=0.8.4)", "pandas (<2.1.0)", "pillow", "plotly"] +examples = ["aiohttp", "altair", "channels", "croniter", "dask-expr", "datashader", "django (<4)", "fastparquet", "folium", "graphviz", "holoviews (>=1.16.0)", "hvplot", "ipyleaflet", "ipympl", "ipyvolume", "ipyvuetify", "ipywidgets", "ipywidgets-bokeh", "jupyter-bokeh (>=3.0.7)", "networkx (>=2.5)", "plotly (>=4.0)", "pydeck", "pygraphviz", "pyinstrument (>=4.0)", "python-graphviz", "pyvista", "reacton", "scikit-image", "scikit-learn", "seaborn", "streamz", "textual", "vega-datasets", "vtk", "xarray", "xgboost"] +recommended = ["holoviews (>=1.16.0)", "jupyterlab", "matplotlib", "pillow", "plotly"] +tests = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipympl", "ipython (>=7.0)", "ipyvuetify", "ipywidgets-bokeh", "nbval", "numba (<0.58)", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "reacton", "scipy", "textual", "twine", "watchfiles"] +tests-core = ["altair", "anywidget", "diskcache", "flake8", "folium", "holoviews (>=1.16.0)", "ipython (>=7.0)", "nbval", "numpy", "pandas (>=1.3)", "parameterized", "pre-commit", "psutil", "pytest", "pytest-asyncio", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "scipy", "textual", "watchfiles"] +ui = ["jupyter-server", "playwright", "pytest-playwright", "tomli"] + +[[package]] +name = "param" +version = "2.1.0" +description = "Make your Python code clearer and more reliable by declaring Parameters." +optional = false +python-versions = ">=3.8" +files = [ + {file = "param-2.1.0-py3-none-any.whl", hash = "sha256:f31d3745d227347d29b5868c4e4e3077df07463889b91d3bb28e634fde211e1c"}, + {file = "param-2.1.0.tar.gz", hash = "sha256:a7b30b08b547e2b78b02aeba6ed34e3c6a638f8e4824a76a96ffa2d7cf57e71f"}, +] + +[package.extras] +all = ["param[doc]", "param[lint]", "param[tests-full]"] +doc = ["nbsite (==0.8.4)", "param[examples]", "sphinx-remove-toctrees"] +examples = ["aiohttp", "pandas", "panel"] +lint = ["flake8", "pre-commit"] +tests = ["coverage[toml]", "pytest", "pytest-asyncio"] +tests-deser = ["odfpy", "openpyxl", "pyarrow", "tables", "xlrd"] +tests-examples = ["nbval", "param[examples]", "pytest (<8.1)", "pytest-asyncio", "pytest-xdist"] +tests-full = ["cloudpickle", "gmpy", "ipython", "jsonschema", "nest-asyncio", "numpy", "pandas", "param[tests-deser]", "param[tests-examples]", "param[tests]"] + [[package]] name = "parso" version = "0.8.4" @@ -3310,6 +3553,54 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyarrow" +version = "16.1.0" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, + {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, + {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, + {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, + {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, + {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, + {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, + {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, + {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, + {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, + {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, + {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, + {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, + {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, + {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, + {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, + {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + [[package]] name = "pycparser" version = "2.22" @@ -3444,6 +3735,36 @@ extra = ["bitsandbytes (==0.41.0)", "hydra-core (>=1.0.5)", "jsonargparse[signat strategies = ["deepspeed (>=0.8.2,<=0.9.3)"] test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"] +[[package]] +name = "pytz" +version = "2024.1" +description = "World timezone definitions, modern and historical" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, + {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, +] + +[[package]] +name = "pyviz-comms" +version = "3.0.2" +description = "A JupyterLab extension for rendering HoloViz content." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyviz_comms-3.0.2-py3-none-any.whl", hash = "sha256:31541b976a21b7738557c3ea23bd8e44e94e736b9ed269570dcc28db4449d7e3"}, + {file = "pyviz_comms-3.0.2.tar.gz", hash = "sha256:3167df932656416c4bd711205dad47e986a3ebae1f316258ddc26f9e01513ef7"}, +] + +[package.dependencies] +param = "*" + +[package.extras] +all = ["pyviz-comms[build]", "pyviz-comms[tests]"] +build = ["jupyterlab (>=4.0,<5.0)", "keyring", "rfc3986", "setuptools (>=40.8.0)", "twine"] +tests = ["flake8", "pytest"] + [[package]] name = "pywin32" version = "306" @@ -4823,6 +5144,17 @@ files = [ mypy-extensions = ">=0.3.0" typing-extensions = ">=3.7.4" +[[package]] +name = "tzdata" +version = "2024.1" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +files = [ + {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, + {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, +] + [[package]] name = "uc-micro-py" version = "1.0.3" @@ -5009,6 +5341,17 @@ files = [ {file = "widgetsnbextension-4.0.10.tar.gz", hash = "sha256:64196c5ff3b9a9183a8e699a4227fb0b7002f252c814098e66c4d1cd0644688f"}, ] +[[package]] +name = "xyzservices" +version = "2024.4.0" +description = "Source of XYZ tiles providers" +optional = false +python-versions = ">=3.8" +files = [ + {file = "xyzservices-2024.4.0-py3-none-any.whl", hash = "sha256:b83e48c5b776c9969fffcfff57b03d02b1b1cd6607a9d9c4e7f568b01ef47f4c"}, + {file = "xyzservices-2024.4.0.tar.gz", hash = "sha256:6a04f11487a6fb77d92a98984cd107fbd9157fd5e65f929add9c3d6e604ee88c"}, +] + [[package]] name = "yarl" version = "1.9.4" @@ -5115,4 +5458,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "74e84f457ce411d0612153d658f0e43f92f140297fcc1812501b74e7757fedc3" +content-hash = "f607472660b04b7f6f5d49a4561730f788a46f0d1e0176322e872111b00481cd" diff --git a/pyproject.toml b/pyproject.toml index a63d622..1627157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ polars = "^0.20.28" jupyter = "^1.0.0" safetensors = "^0.4.3" alive-progress = "^3.1.5" +hvplot = "^0.10.0" +pyarrow = "^16.1.0" [build-system] diff --git a/symbolic_nn_tests/experiment2/model.py b/symbolic_nn_tests/experiment2/model.py index 890a1cd..c86bf76 100644 --- a/symbolic_nn_tests/experiment2/model.py +++ b/symbolic_nn_tests/experiment2/model.py @@ -40,12 +40,12 @@ class Model(nn.Module): x0, x1 = x y0 = self.encode_x0(x0) y1 = self.encode_x1(x1) - x = torch.cat([y0, y1], dim=1) - y = self.ff(x) + y = torch.cat([y0, y1], dim=1) + y = self.ff(y) if self.return_module_y: - return y, y0, y1 + return x, (y, y0, y1) else: - return y + return x, y # This is just a quick, lazy way to ensure all models are trained on the same dataset @@ -59,7 +59,95 @@ def get_singleton_dataset(): ) -def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs): +def smooth_l1_loss(out, y): + _, y_pred = out + return nn.functional.smooth_l1_loss(y_pred, y) + + +def sech(x): + return torch.reciprocal(torch.cosh(x)) + + +def linear_fit(x, y): + mean_x = torch.mean(x) + mean_y = torch.mean(y) + cov_xy = torch.mean(x * y) - (mean_x * mean_y) + var_x = torch.mean(x * x) - (mean_x * mean_x) + m = cov_xy / var_x + c = mean_y - (m * mean_x) + return m, c + + +def line(x, m, c): + return (m * x) + c + + +def linear_residuals(x, y, m, c): + return y - line(x, m, c) + + +def semantic_loss(x, y_pred, w, a): + m, c = linear_fit(x, y_pred) + residuals = linear_residuals(x, y_pred, m, c) + scaled_residuals = residuals * sech(w * x) + slope_penalty = torch.nn.functional.softmax(a * m, dim=0) + loss = torch.mean(scaled_residuals**2) + torch.mean(slope_penalty) + return loss + + +def loss(out, y): + x, y_pred = out + x0, x1 = x + + # Here, we want to make semantic use of the differential electronegativity of the molecule + # so start by calculating that + mean_electronegativities = torch.tensor( + [i[:, 3].mean() for i in x0], dtype=torch.float32 + ).to(y_pred.device) + diff_electronegativity = ( + torch.tensor( + [ + (i[:, 3] - mean).abs().sum() + for i, mean in zip(x0, mean_electronegativities) + ], + dtype=torch.float32, + ) + * 4.0 + ).to(y_pred.device) + + # Then, we need to get a linear best fit on that. Our semantic info is based on a graph of + # En (y) vs differential electronegativity on the x vs y axes, so y_pred is y here + m, c = linear_fit(diff_electronegativity, y_pred) + + # To start with, we want to calculate a penalty based on deviation from a linear relationship + # Scaling is being based on 1/sech(w*r) as this increases multiplier as deviation grows. + # `w` was selected based on noting that the residual spread before eneg scaling was about 25; + # enegs were normalised as x/4, so we want to incentivize a spread of about 25/4~=6, and w=0.2 + # causes the penalty function to cross 2 at just over 6. Yes, that's a bit arbitrary but we're + # just steering the model not applying hard constraints to it shold be fine. + residual_penalty = ( + ( + linear_residuals(diff_electronegativity, y_pred, m, c) + / sech(0.2 * diff_electronegativity) + ) + .abs() + .float() + .mean() + ) + + # We also need to calculate a penalty that incentivizes a positive slope. For this, im using softmax + # to scale the slope as it will penalise negative slopes while not just creating a reward hack for + # maximizing slope. The softmax function approximates 1 from about 5 onwards, so if we multiply m by + # 500, then our penalty should be almost minimised for any slope above 0.01 and maximised below 0.01. + # This should suffice for incentivizing the model to favour positive slopes. + slope_penalty = (torch.nn.functional.softmax(-m * 500.0) + 1).mean() + + # Finally, let's get a smooth L1 loss and scale it based on these penalty functions + return nn.functional.smooth_l1_loss(y_pred, y) * residual_penalty * slope_penalty + + +# def main(loss_func=smooth_l1_loss, logger=None, **kwargs): +def main(loss_func=loss, logger=None, **kwargs): import lightning as L from symbolic_nn_tests.train import TrainingWrapper @@ -72,7 +160,7 @@ def main(loss_func=nn.functional.smooth_l1_loss, logger=None, **kwargs): train, val, test = get_singleton_dataset() lmodel = TrainingWrapper(Model(), loss_func=loss_func) lmodel.configure_optimizers(optimizer=torch.optim.NAdam, **kwargs) - trainer = L.Trainer(max_epochs=20, logger=logger) + trainer = L.Trainer(max_epochs=10, logger=logger, num_sanity_val_steps=0) trainer.fit(model=lmodel, train_dataloaders=train, val_dataloaders=val) trainer.test(dataloaders=test) diff --git a/symbolic_nn_tests/experiment2/semantic_loss.py b/symbolic_nn_tests/experiment2/semantic_loss.py index aaa3285..f52842b 100644 --- a/symbolic_nn_tests/experiment2/semantic_loss.py +++ b/symbolic_nn_tests/experiment2/semantic_loss.py @@ -1 +1,16 @@ import torch + + +# TODO: implement semantic loss functions +# These functions would enforce relationships we expect to be present in the results +# if the model is performing correctly. The first ones to implement would be: +# - En should be proportional to molecular mass +# - Bigger molecule = more dof and more strain +# - En should be inversely proportional to the differential in electronegativity +# - Higher electroneg diff = more stable molecule +# - calc diff as `torch.sum( torch.abs( electronegativity - electronegativity.mean() ) )` +# Best way to enforce this relationship would probably be to apply a multiplier based on +# a normalized sigmoid curve. This would incentivize the model to ensure slope has correct sign +# without creating a reward hack for maximizing/minimizing m and preventing exploding gradients. +# It also allows us to avoid the assumption of linearity: we only care about the direction of +# proportionality. diff --git a/symbolic_nn_tests/experiment2/train.py b/symbolic_nn_tests/experiment2/train.py new file mode 100644 index 0000000..77e3f74 --- /dev/null +++ b/symbolic_nn_tests/experiment2/train.py @@ -0,0 +1,20 @@ +from symbolic_nn_tests.train import TrainingWrapper + + +class SemanticModuleTrainingWrapper(TrainingWrapper): + def __init__(self, model, *args, loss_func0, loss_func1, loss_agg, **kwargs): + assert len(args) == 0 + + super().__init__(model, **kwargs) + self.loss_func0 = loss_func0 + self.loss_func1 = loss_func1 + self.loss_agg = loss_agg + + def _forward_step(self, batch, batch_idx, label=""): + x, y = batch + y_pred, y0, y1 = self.model(x) + loss = self.loss_func(y_pred, y) + loss0 = self.loss_func0(y0, x) + loss1 = self.loss_func1(y1, x) + self.log(f"{label}{'_' if label else ''}loss", loss) + return self.loss_agg(loss, loss0, loss1)