Skip to content

Commit

Permalink
Change cluster_rules to use result set instead of domain for flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Zapiano committed Nov 19, 2024
1 parent 597ff39 commit 3140b75
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/src/usage/analysis.md
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ foi = ADRIA.component_params(rs, [Intervention, SeedCriteriaWeights]).fieldname

# Use SIRUS algorithm to extract rules
max_rules = 10
rules_iv = ADRIA.analysis.cluster_rules(dom, target_clusters, scens, foi, max_rules)
rules_iv = ADRIA.analysis.cluster_rules(rs, target_clusters, scens, foi, max_rules)


# Plot scatters for each rule highlighting the area selected by them
Expand Down
15 changes: 8 additions & 7 deletions src/analysis/rule_extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,21 @@ function print_rules(rules::Vector{Rule{Vector{Vector},Vector{Float64}}})::Nothi
end

"""
cluster_rules(domain::Domain, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Integer,F<:Real}
cluster_rules(domain::Domain, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64}
cluster_rules(result_set::ADRIA.ResultSet, clusters::Vector{T}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; seed::Int64=123, kwargs...) where {T<:Int64}
cluster_rules(result_set::ADRIA.ResultSet, clusters::Union{BitVector,Vector{Bool}}, scenarios::DataFrame, factors::Vector{Symbol}, max_rules::T; kwargs...) where {T<:Int64}
Use SIRUS package to extract rules from time series clusters based on some summary metric
(default is median). More information about the keyword arguments accepted can be found in
MLJ's doc (https://juliaai.github.io/MLJ.jl/dev/models/StableRulesClassifier_SIRUS/).
# Arguments
- `domain` : Domain
- `result_set` : ResultSet
- `clusters` : Vector of cluster indexes for each scenario outcome
- `scenarios` : Scenarios DataFrame
- `factors` : Vector of factors of interest
- `max_rules` : Maximum number of rules, to be used as input by SIRUS
- `seed` : Seed to be used by RNG
- `kwargs` : Keyword arguments to be passed to StableRulesClassifier
# Returns
A StableRules object (implemented by SIRUS).
Expand All @@ -132,15 +133,15 @@ A StableRules object (implemented by SIRUS).
https://doi.org/10.1214/20-EJS1792
"""
function cluster_rules(
domain::ADRIA.Domain,
result_set::ADRIA.ResultSet,
clusters::Vector{T},
scenarios::DataFrame,
factors::Vector{Symbol},
max_rules::T;
seed::Int64=123,
kwargs...
) where {T<:Int64}
ms = ADRIA.model_spec(domain)
ms = ADRIA.model_spec(result_set)
variable_factors_filter::BitVector = .!ms[ms.fieldname .∈ [factors], :is_constant]
variable_factors::Vector{Symbol} = factors[variable_factors_filter]

Expand Down Expand Up @@ -169,15 +170,15 @@ function cluster_rules(
return rules(mach.fitresult)
end
function cluster_rules(
domain::ADRIA.Domain,
result_set::ADRIA.ResultSet,
clusters::Union{BitVector,Vector{Bool}},
scenarios::DataFrame,
factors::Vector{Symbol},
max_rules::T;
kwargs...
) where {T<:Int64}
return cluster_rules(
domain, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs...
result_set, convert.(Int64, clusters), scenarios, factors, max_rules; kwargs...
)
end

Expand Down
7 changes: 4 additions & 3 deletions test/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,17 @@ function test_rs_w_fig(rs::ADRIA.ResultSet, scens::ADRIA.DataFrame)
ADRIA.component_params(
rs, [Intervention, FogCriteriaWeights, SeedCriteriaWeights]
).fieldname
scenarios_iv = scens[:, fields_iv]

# Use SIRUS algorithm to extract rules
max_rules = 4
rules_iv = ADRIA.analysis.cluster_rules(target_clusters, scenarios_iv, max_rules)
rules_iv = ADRIA.analysis.cluster_rules(
rs, target_clusters, scens, fields_iv, max_rules
)

# Plot scatters for each rule highlighting the area selected by them
rules_scatter_fig = ADRIA.viz.rules_scatter(
rs,
scenarios_iv,
scens,
target_clusters,
rules_iv;
fig_opts=fig_opts,
Expand Down

0 comments on commit 3140b75

Please sign in to comment.