From 3140b757d3daad3b442771e20eca111c3d136516 Mon Sep 17 00:00:00 2001 From: Pedro Ribeiro de Almeida Date: Fri, 1 Nov 2024 15:32:34 +1100 Subject: [PATCH] Change cluster_rules to use result set instead of domain for flexibility --- docs/src/usage/analysis.md | 2 +- src/analysis/rule_extraction.jl | 15 ++++++++------- test/analysis.jl | 7 ++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/src/usage/analysis.md b/docs/src/usage/analysis.md index 308908f55..ae80b25ff 100644 --- a/docs/src/usage/analysis.md +++ b/docs/src/usage/analysis.md @@ -410,7 +410,7 @@ foi = ADRIA.component_params(rs, [Intervention, SeedCriteriaWeights]).fieldname # Use SIRUS algorithm to extract rules max_rules = 10 -rules_iv = ADRIA.analysis.cluster_rules(dom, target_clusters, scens, foi, max_rules) +rules_iv = ADRIA.analysis.cluster_rules(rs, target_clusters, scens, foi, max_rules) # Plot scatters for each rule highlighting the area selected them diff --git a/src/analysis/rule_extraction.jl b/src/analysis/rule_extraction.jl index 895d74d44..9bda85673 100644 --- a/src/analysis/rule_extraction.jl +++ b/src/analysis/rule_extraction.jl @@ -103,20 +103,21 @@ function print_rules(rules::Vector{Rule{Vector{Vector},Vector{Float64}}})::Nothi end """ - cluster_rules(domain::Domain, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real} - cluster_rules(domain::Domain, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64} + cluster_rules(result_set::ADRIA.ResultSet, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real} + cluster_rules(result_set::ADRIA.ResultSet, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64} Use SIRUS package to extract rules from time series clusters based on some summary metric (default is median). More information about the keyword arguments accepeted can be found in MLJ's doc (https://juliaai.github.io/MLJ.jl/dev/models/StableRulesClassifier_SIRUS/). # Arguments -- `domain` : Domain +- `result_set` : ResultSet - `clusters` : Vector of cluster indexes for each scenario outcome - `scenarios` : Scenarios DataFrame - `factors` : Vector of factors of interest - `max_rules` : Maximum number of rules, to be used as input by SIRUS - `seed` : Seed to be used by RGN +- `kwargs` : Keyword arguments to be passed to StableRulesClassifier # Returns A StableRules object (implemented by SIRUS). @@ -132,7 +133,7 @@ A StableRules object (implemented by SIRUS). https://doi.org//10.1214/20-EJS1792 """ function cluster_rules( - domain::ADRIA.Domain, + result_set::ADRIA.ResultSet, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, @@ -140,7 +141,7 @@ function cluster_rules( seed::Int64=123, kwargs... ) where {T<:Int64} - ms = ADRIA.model_spec(domain) + ms = ADRIA.model_spec(result_set) variable_factors_filter::BitVector = .!ms[ms.fieldname .∈ [factors], :is_constant] variable_factors::Vector{Symbol} = factors[variable_factors_filter] @@ -169,7 +170,7 @@ function cluster_rules( return rules(mach.fitresult) end function cluster_rules( - domain::ADRIA.Domain, + result_set::ADRIA.ResultSet, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, @@ -177,7 +178,7 @@ function cluster_rules( kwargs... ) where {T<:Int64} return cluster_rules( - domain, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs... + result_set, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs... ) end diff --git a/test/analysis.jl b/test/analysis.jl index e8d7aae5e..ba7cb219c 100644 --- a/test/analysis.jl +++ b/test/analysis.jl @@ -243,16 +243,17 @@ function test_rs_w_fig(rs::ADRIA.ResultSet, scens::ADRIA.DataFrame) ADRIA.component_params( rs, [Intervention, FogCriteriaWeights, SeedCriteriaWeights] ).fieldname - scenarios_iv = scens[:, fields_iv] # Use SIRUS algorithm to extract rules max_rules = 4 - rules_iv = ADRIA.analysis.cluster_rules(target_clusters, scenarios_iv, max_rules) + rules_iv = ADRIA.analysis.cluster_rules( + rs, target_clusters, scens, fields_iv, max_rules + ) # Plot scatters for each rule highlighting the area selected them rules_scatter_fig = ADRIA.viz.rules_scatter( rs, - scenarios_iv, + scens, target_clusters, rules_iv; fig_opts=fig_opts,