diff --git a/src/miemfp.jl b/src/miemfp.jl index aa405df..873d6a3 100755 --- a/src/miemfp.jl +++ b/src/miemfp.jl @@ -68,7 +68,7 @@ Corrects refractive index and attenuation coefficient values to account for the mean free-path effect on the free conductance electrons in the particles (translated from Haiss et als FORTRAN implementation) """ - function mfp(fv::Float64, wavel::Float64, radcor::Float64, omp::Float64, +function mfp(fv::Float64, wavel::Float64, radcor::Float64, omp::Float64, om0::Float64, rn::Float64, rk::Float64)::Tuple{Float64,Float64,Float64} om, om0r, om_sq, _, omp_sq, om0r_sq, om_sq_plus_om0_sq = @@ -185,8 +185,8 @@ function bhmie(x::Float64, refrel::ComplexF64, nang::UInt32 y::ComplexF64 = x * refrel nstop::UInt32 = UInt32(round(x + 4.0 * cbrt(x) + 2.0)) nn::UInt32 = UInt32(round(max(nstop, abs(y)) + 14)) - amu::Vector{Float64} = cos.((1.570796327 / Float64(nang - 1)).*Float64.(0:nang-1)) - + amu::Vector{Float64} = cos.((1.570796327 / Float64(nang - 1)) .* Float64.(0:nang-1)) + d = Vector{ComplexF64}(undef, nn) d[nn] = ComplexF64(0.0, 0.0) @simd for n in nn:-1:2 @@ -273,7 +273,7 @@ end const _cache = Dict{UInt64,InterpolationPair}() "Helper function to get the interpolation objects for rn and rk" - function _get_rnrk_interp_objects(refcore::Array{Float64,2})::InterpolationPair +function _get_rnrk_interp_objects(refcore::Array{Float64,2})::InterpolationPair refcore_hash = hash(refcore) if !haskey(_cache, refcore_hash) _cache[refcore_hash] = InterpolationPair(LinearInterpolation(refcore[:, 1], refcore[:, 2]), LinearInterpolation(refcore[:, 1], refcore[:, 3])) @@ -308,7 +308,7 @@ struct _mfp om0::Float64 end -function(p::_mfp)(wavel::Float64, rn::Float64, rk::Float64)::ComplexF64 +function (p::_mfp)(wavel::Float64, rn::Float64, rk::Float64)::ComplexF64 refre1, refim1, _ = mfp(p.fv, wavel, p.radcore, p.omp, p.om0, rn, rk) return ComplexF64(refre1, refim1)::ComplexF64 end diff --git a/src/nanoconc.jl b/src/nanoconc.jl index 088f6da..b0a9728 100755 --- a/src/nanoconc.jl +++ b/src/nanoconc.jl @@ -269,11 +269,11 @@ struct q_to_sigma geometric_cross_section::Vector{Float64} function q_to_sigma(diameter::Vector{Float64})::q_to_sigma - new(π .* ((diameter ./ 2).^2)) - end + new(π .* ((diameter ./ 2) .^ 2)) + end end -function (p::q_to_sigma)(q::Matrix{Float64})::Array{Float64, 3} +function (p::q_to_sigma)(q::Matrix{Float64})::Array{Float64,3} # Calculate the extinction cross-section hcat([q .* x for x in p.geometric_cross_section]) end diff --git a/test/build_ffi.sh b/test/build_ffi.sh index 895ee79..7efd9c0 100755 --- a/test/build_ffi.sh +++ b/test/build_ffi.sh @@ -13,7 +13,7 @@ codebases=("bhmie-f.zip" "bhmie-c.zip") # if bhmie_dir containers bhmie-c/bhmie.so, bhmie-f/bhmie.so, and bhmie-f/bhmie_f77.so then we can skip the build bhmie_dir=$1 if [ -f $bhmie_dir/bhmie-c/bhmie.so ] && [ -f $bhmie_dir/bhmie-f/bhmie.so ] && [ -f $bhmie_dir/bhmie-f/bhmie_f77.so ]; then - echo "bhmie-c/bhmie.so, bhmie-f/bhmie.so, and bhmie-f/bhmie_f77.so already exist. Skipping build." + echo "FFI shared objects already exist. Skipping build." exit 0 fi diff --git a/test/ffi_wraps.jl b/test/ffi_wraps.jl index c405630..91dd455 100644 --- a/test/ffi_wraps.jl +++ b/test/ffi_wraps.jl @@ -1,7 +1,5 @@ module FFIWraps -include("../anchors.jl") - if !@isdefined TEST_DIR include("../anchors.jl") import .Anchors: TEST_DIR @@ -45,8 +43,8 @@ function bhmie_fortran(x::Float32, refrel::ComplexF32, nang::Int32, s1::Vector{C # Call the Fortran subroutine ccall((:bhmie_, "$BHMIELIBS_DIR/bhmie-f/bhmie.so"), Cvoid, - (Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}), - x, refrel, nang, s1, s2, qext, qsca, qback, gsca) + (Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}), + x, refrel, nang, s1, s2, qext, qsca, qback, gsca) # Return the modified values return qext[], qsca[], qback[], gsca[] @@ -61,8 +59,8 @@ function bhmie_fortran77(x::Float32, refrel::ComplexF32, nang::Int32, s1::Vector # Call the Fortran subroutine ccall((:bhmie_, "$BHMIELIBS_DIR/bhmie-f/bhmie.so"), Cvoid, - (Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}), - x, refrel, nang, s1, s2, qext, qsca, qback, gsca) + (Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}), + x, refrel, nang, s1, s2, qext, qsca, qback, gsca) # Return the modified values return qext[], qsca[], qback[], gsca[] diff --git a/test/miemfp_tests.jl b/test/miemfp_tests.jl index c478faf..85e3e44 100644 --- a/test/miemfp_tests.jl +++ b/test/miemfp_tests.jl @@ -4,25 +4,17 @@ using PropCheck using Debugger using PyCall +if !@isdefined TestUtils + include("testutils.jl") +end -include("../anchors.jl") +TestUtils.init_pyenv() +TestUtils.singleton_include("../anchors.jl", :Anchors, @__MODULE__) import .Anchors: TEST_DIR, SRC_DIR, ROOT_DIR -if !@isdefined TestUtils - include(joinpath(TEST_DIR, "testutils.jl")) -end -if !@isdefined miemfp - include(joinpath(SRC_DIR, "miemfp.jl")) -end -if !@isdefined FFIWraps - include(joinpath(TEST_DIR, "ffi_wraps.jl")) -end - - -# Set up the Python environment -run(`$ROOT_DIR/setup_venv.sh`) -ENV["PYTHON"] = joinpath(ROOT_DIR, ".venv/bin/python") +TestUtils.singleton_include(joinpath(SRC_DIR, "miemfp.jl"), :miemfp, @__MODULE__) +TestUtils.singleton_include(joinpath(TEST_DIR, "ffi_wraps.jl"), :FFIWraps, @__MODULE__) @pyinclude(joinpath(TEST_DIR, "miemfp_tests.py")) diff --git a/test/miemfp_tests.py b/test/miemfp_tests.py index 8e5962b..c2f9c3e 100644 --- a/test/miemfp_tests.py +++ b/test/miemfp_tests.py @@ -1,9 +1,14 @@ -from typing import List, Tuple import asyncio -import numpy as np -from hypothesis import errors, given, settings, strategies as st # type: ignore +from typing import List, Tuple -def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event) -> Tuple[bool, str]: +import numpy as np +from hypothesis import errors, given, settings # type: ignore +from hypothesis import strategies as st # type: ignore + + +def compare_bhmie_functions( + f1, f2, event1: asyncio.Event, event2: asyncio.Event +) -> Tuple[bool, str]: async def async_closure( x: float, cxref: Tuple[float, float], @@ -13,24 +18,41 @@ def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event cxref = complex(*cxref) cxs1 = [complex(*c) for c in cxs1] cxs2 = [complex(*c) for c in cxs2] - + # This is to ensure that only one instance of each function is running at a time # to avoid memory issues in the FFI code await event1.wait() f1_result = f1(x, cxref, 2, cxs1, cxs2)[:2] await event2.wait() f2_result = f2(x, cxref, 2, cxs1, cxs2)[:2] - + return np.all(np.isclose(f1_result, f2_result)) - + @settings(deadline=None) @given( # Must be bigger than an atom but still nanoscale x=st.floats(min_value=0.1, max_value=100), # Refractive indeces must be within a physically reasonable range - cxref=st.tuples(st.floats(min_value=0.1, max_value=4.0), st.floats(min_value=0.1, max_value=4.0)), - cxs1=st.lists(st.tuples(st.floats(min_value=0.1, allow_infinity=False), st.floats(min_value=0.1, allow_infinity=False)), min_size=10, max_size=100), - cxs2=st.lists(st.tuples(st.floats(min_value=0.1, allow_infinity=False), st.floats(min_value=0.1, allow_infinity=False)), min_size=10, max_size=100), + cxref=st.tuples( + st.floats(min_value=0.1, max_value=4.0), + st.floats(min_value=0.1, max_value=4.0), + ), + cxs1=st.lists( + st.tuples( + st.floats(min_value=0.1, allow_infinity=False), + st.floats(min_value=0.1, allow_infinity=False), + ), + min_size=10, + max_size=100, + ), + cxs2=st.lists( + st.tuples( + st.floats(min_value=0.1, allow_infinity=False), + st.floats(min_value=0.1, allow_infinity=False), + ), + min_size=10, + max_size=100, + ), ) def sync_closure( x: float, @@ -39,11 +61,11 @@ def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event cxs2: List[Tuple[float, float]], ) -> bool: assert asyncio.run(async_closure(x, cxref, cxs1, cxs2)) - + try: sync_closure() return True, "Test passed" except AssertionError as e: return False, f"AssertionError: {str(e)}" except errors.HypothesisException as e: - return False, f"HypothesisException: {str(e)}" \ No newline at end of file + return False, f"HypothesisException: {str(e)}" diff --git a/test/runtests.jl b/test/runtests.jl index 52ccef4..1a49e25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,9 +3,12 @@ using Test include("testutils.jl") include("../src/nanoconc.jl") +# Set up the Python environment +TestUtils.init_pyenv() + # include("nanoconc_tests.jl") -include("miemfp_tests.jl") +TestUtils.singleton_include("miemfp_tests.jl", :miemfp, @__MODULE__) # include("quantumcalc_tests.jl") -include("benchmarks.jl") +TestUtils.singleton_include("benchmarks.jl", :Benchmarks, @__MODULE__) Benchmarks.bench_vs_ffi() \ No newline at end of file diff --git a/test/testutils.jl b/test/testutils.jl index b98c7ea..3ac9862 100644 --- a/test/testutils.jl +++ b/test/testutils.jl @@ -3,16 +3,39 @@ using Serialization using Test +function singleton_include(filepath::String, m::Symbol, calling_module::Module) + try + target = getproperty(calling_module, m) + if !isdefined(target, :Module) + raise("Module $m not defined in $calling_module") + end + catch + Base.include(calling_module, filepath) + end +end + +singleton_include("../anchors.jl", :Anchors, @__MODULE__) + +import .Anchors: TEST_DIR, SRC_DIR, ROOT_DIR + +function init_pyenv() + if "PYTHON" in keys(ENV) && ENV["PYTHON"] != joinpath(ROOT_DIR, ".venv/bin/python") + run(`$ROOT_DIR/setup_venv.sh`) + ENV["PYTHON"] = joinpath(ROOT_DIR, ".venv/bin/python") + end +end + + function fieldvalues(obj) [getfield(obj, f) for f in fieldnames(typeof(obj))] end -deep_compare(a, b; rtol::Real=sqrt(eps()), atol::Real=0.) = a == b # base case: use == -deep_compare(a::Union{AbstractFloat, Complex}, b::Union{AbstractFloat, Complex}; rtol::Real=sqrt(eps()), atol::Real=0.) = isapprox(a, b) # for floats, use isapprox -deep_compare(a::Union{Array{AbstractFloat}, Array{Complex}}, b::Union{Array{AbstractFloat}, Array{Complex}}; rtol::Real=sqrt(eps()), atol::Real=0.) = isapprox(a, b) # for arrays of floats, use isapprox element-wise -deep_compare(a::AbstractArray, b::AbstractArray; rtol::Real=sqrt(eps()), atol::Real=0.) = all(deep_compare.(a, b; rtol, atol)) # for arrays of other types, recurse -deep_compare(a::Tuple, b::Tuple; rtol::Real=sqrt(eps()), atol::Real=0.) = all(deep_compare.(a, b; rtol, atol)) # for tuples, recurse -deep_compare(a::T, b::T; rtol::Real=sqrt(eps()), atol::Real=0.) where {T <: Any} = deep_compare(fieldvalues(a), fieldvalues(b); rtol, atol) # for composite types, recurse +deep_compare(a, b; rtol::Real=sqrt(eps()), atol::Real=0.0) = a == b # base case: use == +deep_compare(a::Union{AbstractFloat,Complex}, b::Union{AbstractFloat,Complex}; rtol::Real=sqrt(eps()), atol::Real=0.0) = isapprox(a, b) # for floats, use isapprox +deep_compare(a::Union{Array{AbstractFloat},Array{Complex}}, b::Union{Array{AbstractFloat},Array{Complex}}; rtol::Real=sqrt(eps()), atol::Real=0.0) = isapprox(a, b) # for arrays of floats, use isapprox element-wise +deep_compare(a::AbstractArray, b::AbstractArray; rtol::Real=sqrt(eps()), atol::Real=0.0) = all(deep_compare.(a, b; rtol, atol)) # for arrays of other types, recurse +deep_compare(a::Tuple, b::Tuple; rtol::Real=sqrt(eps()), atol::Real=0.0) = all(deep_compare.(a, b; rtol, atol)) # for tuples, recurse +deep_compare(a::T, b::T; rtol::Real=sqrt(eps()), atol::Real=0.0) where {T<:Any} = deep_compare(fieldvalues(a), fieldvalues(b); rtol, atol) # for composite types, recurse function test_from_serialized(fn::Function, filename::String) argskwargs, out = open(filename, "r") do f