diff --git a/docs/src/usage/analysis.md b/docs/src/usage/analysis.md index 308908f55..ae80b25ff 100644 --- a/docs/src/usage/analysis.md +++ b/docs/src/usage/analysis.md @@ -410,7 +410,7 @@ foi = ADRIA.component_params(rs, [Intervention, SeedCriteriaWeights]).fieldname # Use SIRUS algorithm to extract rules max_rules = 10 -rules_iv = ADRIA.analysis.cluster_rules(dom, target_clusters, scens, foi, max_rules) +rules_iv = ADRIA.analysis.cluster_rules(rs, target_clusters, scens, foi, max_rules) # Plot scatters for each rule highlighting the area selected them diff --git a/src/analysis/rule_extraction.jl b/src/analysis/rule_extraction.jl index 895d74d44..9bda85673 100644 --- a/src/analysis/rule_extraction.jl +++ b/src/analysis/rule_extraction.jl @@ -103,20 +103,21 @@ function print_rules(rules::Vector{Rule{Vector{Vector},Vector{Float64}}})::Nothi end """ - cluster_rules(domain::Domain, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real} - cluster_rules(domain::Domain, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64} + cluster_rules(result_set::ADRIA.ResultSet, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real} + cluster_rules(result_set::ADRIA.ResultSet, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64} Use SIRUS package to extract rules from time series clusters based on some summary metric (default is median). More information about the keyword arguments accepeted can be found in MLJ's doc (https://juliaai.github.io/MLJ.jl/dev/models/StableRulesClassifier_SIRUS/). # Arguments -- `domain` : Domain +- `result_set` : ResultSet - `clusters` : Vector of cluster indexes for each scenario outcome - `scenarios` : Scenarios DataFrame - `factors` : Vector of factors of interest - `max_rules` : Maximum number of rules, to be used as input by SIRUS - `seed` : Seed to be used by RGN +- `kwargs` : Keyword arguments to be passed to StableRulesClassifier # Returns A StableRules object (implemented by SIRUS). @@ -132,7 +133,7 @@ A StableRules object (implemented by SIRUS). https://doi.org//10.1214/20-EJS1792 """ function cluster_rules( - domain::ADRIA.Domain, + result_set::ADRIA.ResultSet, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, @@ -140,7 +141,7 @@ function cluster_rules( seed::Int64=123, kwargs... ) where {T<:Int64} - ms = ADRIA.model_spec(domain) + ms = ADRIA.model_spec(result_set) variable_factors_filter::BitVector = .!ms[ms.fieldname .∈ [factors], :is_constant] variable_factors::Vector{Symbol} = factors[variable_factors_filter] @@ -169,7 +170,7 @@ function cluster_rules( return rules(mach.fitresult) end function cluster_rules( - domain::ADRIA.Domain, + result_set::ADRIA.ResultSet, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, @@ -177,7 +178,7 @@ function cluster_rules( kwargs... ) where {T<:Int64} return cluster_rules( - domain, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs... + result_set, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs... ) end diff --git a/test/analysis.jl b/test/analysis.jl index e8d7aae5e..ba7cb219c 100644 --- a/test/analysis.jl +++ b/test/analysis.jl @@ -243,16 +243,17 @@ function test_rs_w_fig(rs::ADRIA.ResultSet, scens::ADRIA.DataFrame) ADRIA.component_params( rs, [Intervention, FogCriteriaWeights, SeedCriteriaWeights] ).fieldname - scenarios_iv = scens[:, fields_iv] # Use SIRUS algorithm to extract rules max_rules = 4 - rules_iv = ADRIA.analysis.cluster_rules(target_clusters, scenarios_iv, max_rules) + rules_iv = ADRIA.analysis.cluster_rules( + rs, target_clusters, scens, fields_iv, max_rules + ) # Plot scatters for each rule highlighting the area selected them rules_scatter_fig = ADRIA.viz.rules_scatter( rs, - scenarios_iv, + scens, target_clusters, rules_iv; fig_opts=fig_opts,