mirror of
https://github.com/Cian-H/nanoconc.git
synced 2025-12-22 14:12:00 +00:00
Some DRY of tests refactoring where appropriate and a quick format/lint
This commit is contained in:
@@ -68,7 +68,7 @@ Corrects refractive index and attenuation coefficient values to account for
|
||||
the mean free-path effect on the free conductance electrons in the particles
|
||||
(translated from Haiss et als FORTRAN implementation)
|
||||
"""
|
||||
function mfp(fv::Float64, wavel::Float64, radcor::Float64, omp::Float64,
|
||||
function mfp(fv::Float64, wavel::Float64, radcor::Float64, omp::Float64,
|
||||
om0::Float64, rn::Float64, rk::Float64)::Tuple{Float64,Float64,Float64}
|
||||
|
||||
om, om0r, om_sq, _, omp_sq, om0r_sq, om_sq_plus_om0_sq =
|
||||
@@ -185,8 +185,8 @@ function bhmie(x::Float64, refrel::ComplexF64, nang::UInt32
|
||||
y::ComplexF64 = x * refrel
|
||||
nstop::UInt32 = UInt32(round(x + 4.0 * cbrt(x) + 2.0))
|
||||
nn::UInt32 = UInt32(round(max(nstop, abs(y)) + 14))
|
||||
amu::Vector{Float64} = cos.((1.570796327 / Float64(nang - 1)).*Float64.(0:nang-1))
|
||||
|
||||
amu::Vector{Float64} = cos.((1.570796327 / Float64(nang - 1)) .* Float64.(0:nang-1))
|
||||
|
||||
d = Vector{ComplexF64}(undef, nn)
|
||||
d[nn] = ComplexF64(0.0, 0.0)
|
||||
@simd for n in nn:-1:2
|
||||
@@ -273,7 +273,7 @@ end
|
||||
const _cache = Dict{UInt64,InterpolationPair}()
|
||||
|
||||
"Helper function to get the interpolation objects for rn and rk"
|
||||
function _get_rnrk_interp_objects(refcore::Array{Float64,2})::InterpolationPair
|
||||
function _get_rnrk_interp_objects(refcore::Array{Float64,2})::InterpolationPair
|
||||
refcore_hash = hash(refcore)
|
||||
if !haskey(_cache, refcore_hash)
|
||||
_cache[refcore_hash] = InterpolationPair(LinearInterpolation(refcore[:, 1], refcore[:, 2]), LinearInterpolation(refcore[:, 1], refcore[:, 3]))
|
||||
@@ -308,7 +308,7 @@ struct _mfp
|
||||
om0::Float64
|
||||
end
|
||||
|
||||
function(p::_mfp)(wavel::Float64, rn::Float64, rk::Float64)::ComplexF64
|
||||
function (p::_mfp)(wavel::Float64, rn::Float64, rk::Float64)::ComplexF64
|
||||
refre1, refim1, _ = mfp(p.fv, wavel, p.radcore, p.omp, p.om0, rn, rk)
|
||||
return ComplexF64(refre1, refim1)::ComplexF64
|
||||
end
|
||||
|
||||
@@ -269,11 +269,11 @@ struct q_to_sigma
|
||||
geometric_cross_section::Vector{Float64}
|
||||
|
||||
function q_to_sigma(diameter::Vector{Float64})::q_to_sigma
|
||||
new(π .* ((diameter ./ 2).^2))
|
||||
end
|
||||
new(π .* ((diameter ./ 2) .^ 2))
|
||||
end
|
||||
end
|
||||
|
||||
function (p::q_to_sigma)(q::Matrix{Float64})::Array{Float64, 3}
|
||||
function (p::q_to_sigma)(q::Matrix{Float64})::Array{Float64,3}
|
||||
# Calculate the extinction cross-section
|
||||
hcat([q .* x for x in p.geometric_cross_section])
|
||||
end
|
||||
|
||||
@@ -13,7 +13,7 @@ codebases=("bhmie-f.zip" "bhmie-c.zip")
|
||||
# if bhmie_dir containers bhmie-c/bhmie.so, bhmie-f/bhmie.so, and bhmie-f/bhmie_f77.so then we can skip the build
|
||||
bhmie_dir=$1
|
||||
if [ -f $bhmie_dir/bhmie-c/bhmie.so ] && [ -f $bhmie_dir/bhmie-f/bhmie.so ] && [ -f $bhmie_dir/bhmie-f/bhmie_f77.so ]; then
|
||||
echo "bhmie-c/bhmie.so, bhmie-f/bhmie.so, and bhmie-f/bhmie_f77.so already exist. Skipping build."
|
||||
echo "FFI shared objects already exist. Skipping build."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
module FFIWraps
|
||||
|
||||
include("../anchors.jl")
|
||||
|
||||
if !@isdefined TEST_DIR
|
||||
include("../anchors.jl")
|
||||
import .Anchors: TEST_DIR
|
||||
@@ -45,8 +43,8 @@ function bhmie_fortran(x::Float32, refrel::ComplexF32, nang::Int32, s1::Vector{C
|
||||
|
||||
# Call the Fortran subroutine
|
||||
ccall((:bhmie_, "$BHMIELIBS_DIR/bhmie-f/bhmie.so"), Cvoid,
|
||||
(Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}),
|
||||
x, refrel, nang, s1, s2, qext, qsca, qback, gsca)
|
||||
(Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}),
|
||||
x, refrel, nang, s1, s2, qext, qsca, qback, gsca)
|
||||
|
||||
# Return the modified values
|
||||
return qext[], qsca[], qback[], gsca[]
|
||||
@@ -61,8 +59,8 @@ function bhmie_fortran77(x::Float32, refrel::ComplexF32, nang::Int32, s1::Vector
|
||||
|
||||
# Call the Fortran subroutine
|
||||
ccall((:bhmie_, "$BHMIELIBS_DIR/bhmie-f/bhmie.so"), Cvoid,
|
||||
(Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}),
|
||||
x, refrel, nang, s1, s2, qext, qsca, qback, gsca)
|
||||
(Ref{Float32}, Ref{ComplexF32}, Ref{Int32}, Ptr{ComplexF32}, Ptr{ComplexF32}, Ref{Float32}, Ref{Float32}, Ref{Float32}, Ref{Float32}),
|
||||
x, refrel, nang, s1, s2, qext, qsca, qback, gsca)
|
||||
|
||||
# Return the modified values
|
||||
return qext[], qsca[], qback[], gsca[]
|
||||
|
||||
@@ -4,25 +4,17 @@ using PropCheck
|
||||
using Debugger
|
||||
using PyCall
|
||||
|
||||
if !@isdefined TestUtils
|
||||
include("testutils.jl")
|
||||
end
|
||||
|
||||
include("../anchors.jl")
|
||||
TestUtils.init_pyenv()
|
||||
TestUtils.singleton_include("../anchors.jl", :Anchors, @__MODULE__)
|
||||
|
||||
import .Anchors: TEST_DIR, SRC_DIR, ROOT_DIR
|
||||
|
||||
if !@isdefined TestUtils
|
||||
include(joinpath(TEST_DIR, "testutils.jl"))
|
||||
end
|
||||
if !@isdefined miemfp
|
||||
include(joinpath(SRC_DIR, "miemfp.jl"))
|
||||
end
|
||||
if !@isdefined FFIWraps
|
||||
include(joinpath(TEST_DIR, "ffi_wraps.jl"))
|
||||
end
|
||||
|
||||
|
||||
# Set up the Python environment
|
||||
run(`$ROOT_DIR/setup_venv.sh`)
|
||||
ENV["PYTHON"] = joinpath(ROOT_DIR, ".venv/bin/python")
|
||||
TestUtils.singleton_include(joinpath(SRC_DIR, "miemfp.jl"), :miemfp, @__MODULE__)
|
||||
TestUtils.singleton_include(joinpath(TEST_DIR, "ffi_wraps.jl"), :FFIWraps, @__MODULE__)
|
||||
|
||||
@pyinclude(joinpath(TEST_DIR, "miemfp_tests.py"))
|
||||
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import List, Tuple
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from hypothesis import errors, given, settings, strategies as st # type: ignore
|
||||
from typing import List, Tuple
|
||||
|
||||
def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event) -> Tuple[bool, str]:
|
||||
import numpy as np
|
||||
from hypothesis import errors, given, settings # type: ignore
|
||||
from hypothesis import strategies as st # type: ignore
|
||||
|
||||
|
||||
def compare_bhmie_functions(
|
||||
f1, f2, event1: asyncio.Event, event2: asyncio.Event
|
||||
) -> Tuple[bool, str]:
|
||||
async def async_closure(
|
||||
x: float,
|
||||
cxref: Tuple[float, float],
|
||||
@@ -13,24 +18,41 @@ def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event
|
||||
cxref = complex(*cxref)
|
||||
cxs1 = [complex(*c) for c in cxs1]
|
||||
cxs2 = [complex(*c) for c in cxs2]
|
||||
|
||||
|
||||
# This is to ensure that only one instance of each function is running at a time
|
||||
# to avoid memory issues in the FFI code
|
||||
await event1.wait()
|
||||
f1_result = f1(x, cxref, 2, cxs1, cxs2)[:2]
|
||||
await event2.wait()
|
||||
f2_result = f2(x, cxref, 2, cxs1, cxs2)[:2]
|
||||
|
||||
|
||||
return np.all(np.isclose(f1_result, f2_result))
|
||||
|
||||
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
# Must be bigger than an atom but still nanoscale
|
||||
x=st.floats(min_value=0.1, max_value=100),
|
||||
# Refractive indeces must be within a physically reasonable range
|
||||
cxref=st.tuples(st.floats(min_value=0.1, max_value=4.0), st.floats(min_value=0.1, max_value=4.0)),
|
||||
cxs1=st.lists(st.tuples(st.floats(min_value=0.1, allow_infinity=False), st.floats(min_value=0.1, allow_infinity=False)), min_size=10, max_size=100),
|
||||
cxs2=st.lists(st.tuples(st.floats(min_value=0.1, allow_infinity=False), st.floats(min_value=0.1, allow_infinity=False)), min_size=10, max_size=100),
|
||||
cxref=st.tuples(
|
||||
st.floats(min_value=0.1, max_value=4.0),
|
||||
st.floats(min_value=0.1, max_value=4.0),
|
||||
),
|
||||
cxs1=st.lists(
|
||||
st.tuples(
|
||||
st.floats(min_value=0.1, allow_infinity=False),
|
||||
st.floats(min_value=0.1, allow_infinity=False),
|
||||
),
|
||||
min_size=10,
|
||||
max_size=100,
|
||||
),
|
||||
cxs2=st.lists(
|
||||
st.tuples(
|
||||
st.floats(min_value=0.1, allow_infinity=False),
|
||||
st.floats(min_value=0.1, allow_infinity=False),
|
||||
),
|
||||
min_size=10,
|
||||
max_size=100,
|
||||
),
|
||||
)
|
||||
def sync_closure(
|
||||
x: float,
|
||||
@@ -39,11 +61,11 @@ def compare_bhmie_functions(f1, f2, event1: asyncio.Event, event2: asyncio.Event
|
||||
cxs2: List[Tuple[float, float]],
|
||||
) -> bool:
|
||||
assert asyncio.run(async_closure(x, cxref, cxs1, cxs2))
|
||||
|
||||
|
||||
try:
|
||||
sync_closure()
|
||||
return True, "Test passed"
|
||||
except AssertionError as e:
|
||||
return False, f"AssertionError: {str(e)}"
|
||||
except errors.HypothesisException as e:
|
||||
return False, f"HypothesisException: {str(e)}"
|
||||
return False, f"HypothesisException: {str(e)}"
|
||||
|
||||
@@ -3,9 +3,12 @@ using Test
|
||||
include("testutils.jl")
|
||||
include("../src/nanoconc.jl")
|
||||
|
||||
# Set up the Python environment
|
||||
TestUtils.init_pyenv()
|
||||
|
||||
# include("nanoconc_tests.jl")
|
||||
include("miemfp_tests.jl")
|
||||
TestUtils.singleton_include("miemfp_tests.jl", :miemfp, @__MODULE__)
|
||||
# include("quantumcalc_tests.jl")
|
||||
include("benchmarks.jl")
|
||||
TestUtils.singleton_include("benchmarks.jl", :Benchmarks, @__MODULE__)
|
||||
|
||||
Benchmarks.bench_vs_ffi()
|
||||
@@ -3,16 +3,39 @@ using Serialization
|
||||
|
||||
using Test
|
||||
|
||||
function singleton_include(filepath::String, m::Symbol, calling_module::Module)
|
||||
try
|
||||
target = getproperty(calling_module, m)
|
||||
if !isdefined(target, :Module)
|
||||
raise("Module $m not defined in $calling_module")
|
||||
end
|
||||
catch
|
||||
Base.include(calling_module, filepath)
|
||||
end
|
||||
end
|
||||
|
||||
singleton_include("../anchors.jl", :Anchors, @__MODULE__)
|
||||
|
||||
import .Anchors: TEST_DIR, SRC_DIR, ROOT_DIR
|
||||
|
||||
function init_pyenv()
|
||||
if "PYTHON" in keys(ENV) && ENV["PYTHON"] != joinpath(ROOT_DIR, ".venv/bin/python")
|
||||
run(`$ROOT_DIR/setup_venv.sh`)
|
||||
ENV["PYTHON"] = joinpath(ROOT_DIR, ".venv/bin/python")
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function fieldvalues(obj)
|
||||
[getfield(obj, f) for f in fieldnames(typeof(obj))]
|
||||
end
|
||||
|
||||
deep_compare(a, b; rtol::Real=sqrt(eps()), atol::Real=0.) = a == b # base case: use ==
|
||||
deep_compare(a::Union{AbstractFloat, Complex}, b::Union{AbstractFloat, Complex}; rtol::Real=sqrt(eps()), atol::Real=0.) = isapprox(a, b) # for floats, use isapprox
|
||||
deep_compare(a::Union{Array{AbstractFloat}, Array{Complex}}, b::Union{Array{AbstractFloat}, Array{Complex}}; rtol::Real=sqrt(eps()), atol::Real=0.) = isapprox(a, b) # for arrays of floats, use isapprox element-wise
|
||||
deep_compare(a::AbstractArray, b::AbstractArray; rtol::Real=sqrt(eps()), atol::Real=0.) = all(deep_compare.(a, b; rtol, atol)) # for arrays of other types, recurse
|
||||
deep_compare(a::Tuple, b::Tuple; rtol::Real=sqrt(eps()), atol::Real=0.) = all(deep_compare.(a, b; rtol, atol)) # for tuples, recurse
|
||||
deep_compare(a::T, b::T; rtol::Real=sqrt(eps()), atol::Real=0.) where {T <: Any} = deep_compare(fieldvalues(a), fieldvalues(b); rtol, atol) # for composite types, recurse
|
||||
deep_compare(a, b; rtol::Real=sqrt(eps()), atol::Real=0.0) = a == b # base case: use ==
|
||||
deep_compare(a::Union{AbstractFloat,Complex}, b::Union{AbstractFloat,Complex}; rtol::Real=sqrt(eps()), atol::Real=0.0) = isapprox(a, b) # for floats, use isapprox
|
||||
deep_compare(a::Union{Array{AbstractFloat},Array{Complex}}, b::Union{Array{AbstractFloat},Array{Complex}}; rtol::Real=sqrt(eps()), atol::Real=0.0) = isapprox(a, b) # for arrays of floats, use isapprox element-wise
|
||||
deep_compare(a::AbstractArray, b::AbstractArray; rtol::Real=sqrt(eps()), atol::Real=0.0) = all(deep_compare.(a, b; rtol, atol)) # for arrays of other types, recurse
|
||||
deep_compare(a::Tuple, b::Tuple; rtol::Real=sqrt(eps()), atol::Real=0.0) = all(deep_compare.(a, b; rtol, atol)) # for tuples, recurse
|
||||
deep_compare(a::T, b::T; rtol::Real=sqrt(eps()), atol::Real=0.0) where {T<:Any} = deep_compare(fieldvalues(a), fieldvalues(b); rtol, atol) # for composite types, recurse
|
||||
|
||||
function test_from_serialized(fn::Function, filename::String)
|
||||
argskwargs, out = open(filename, "r") do f
|
||||
|
||||
Reference in New Issue
Block a user