diff --git a/src/chainrules.jl b/src/chainrules.jl index 00f3beabd..aed8b8db3 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,5 +1,5 @@ # Non Differentiable Functions -CRC.@non_differentiable replicate(::Any) +CRC.@non_differentiable replicate(::Any) # TODO: move to LuxCore.jl CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) CRC.@non_differentiable istraining(::Any) CRC.@non_differentiable _get_norm_except_dims(::Any, ::Any) @@ -7,8 +7,8 @@ CRC.@non_differentiable _affine(::Any) CRC.@non_differentiable _track_stats(::Any) CRC.@non_differentiable _conv_transpose_dims(::Any...) CRC.@non_differentiable _calc_padding(::Any...) -CRC.@non_differentiable Base.printstyled(::Any...) -CRC.@non_differentiable fieldcount(::Any) ## Type Piracy: Needs upstreaming +CRC.@non_differentiable Base.printstyled(::Any...) # TODO: Move to ChainRules.jl +CRC.@non_differentiable fieldcount(::Any) # TODO: Move to ChainRules.jl CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) CRC.@non_differentiable __set_refval!(::Any...) CRC.@non_differentiable __state_if_stateful(::Any) diff --git a/src/forwarddiff/nested_ad.jl b/src/forwarddiff/nested_ad.jl index ac7391c6e..8ec280cda 100644 --- a/src/forwarddiff/nested_ad.jl +++ b/src/forwarddiff/nested_ad.jl @@ -31,11 +31,10 @@ for type in (:Gradient, :Jacobian) end rrule_call = if type == :Gradient - :((res, pb_f) = CRC.rrule_via_ad( - cfg, Lux.__internal_ad_gradient_call, grad_fn, f, x, y)) + :((res, pb_f) = CRC.rrule_via_ad(cfg, __internal_ad_gradient_call, grad_fn, f, x, y)) else :((res, pb_f) = CRC.rrule_via_ad( - cfg, Lux.__internal_ad_jacobian_call, ForwardDiff.$(fname), grad_fn, f, x, y)) + cfg, __internal_ad_jacobian_call, ForwardDiff.$(fname), grad_fn, f, x, y)) end ret_expr = type == :Gradient ? :(only(res)) : :(res) @eval begin diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 7701c0f5e..c7c044963 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,8 +1,10 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore + using Aqua, ChainRulesCore, ForwardDiff Aqua.test_all(Lux; piracies=false, ambiguities=false) - Aqua.test_ambiguities(Lux; recursive=false) + Aqua.test_ambiguities(Lux; + exclude=[ForwardDiff.jacobian, ForwardDiff.gradient, + Lux.__batched_jacobian, Lux.__jacobian_vector_product_impl]) Aqua.test_piracies( Lux; treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) end