diff --git a/dev/.documenter-siteinfo.json b/dev/.documenter-siteinfo.json index 0acb996a..c5d1f90f 100644 --- a/dev/.documenter-siteinfo.json +++ b/dev/.documenter-siteinfo.json @@ -1 +1 @@ -{"documenter":{"julia_version":"1.6.7","generation_timestamp":"2024-03-05T17:35:26","documenter_version":"1.3.0"}} \ No newline at end of file +{"documenter":{"julia_version":"1.6.7","generation_timestamp":"2024-03-05T20:33:58","documenter_version":"1.3.0"}} \ No newline at end of file diff --git a/dev/about/index.html b/dev/about/index.html index f6f09b16..f1263e5a 100644 --- a/dev/about/index.html +++ b/dev/about/index.html @@ -1,2 +1,2 @@ -About · Imbalance.jl

Credits

This package was created by Essam Wisam as a Google Summer of Code project, under the mentorship of Anthony Blaom. Special thanks also go to Rik Huijzer for his friendliness and the binary SMOTE implementation in Resample.jl.

+About · Imbalance.jl

Credits

This package was created by Essam Wisam as a Google Summer of Code project, under the mentorship of Anthony Blaom. Special thanks also go to Rik Huijzer for his friendliness and the binary SMOTE implementation in Resample.jl.

diff --git a/dev/algorithms/extra_algorithms/index.html b/dev/algorithms/extra_algorithms/index.html index adda52ce..705fc5dc 100644 --- a/dev/algorithms/extra_algorithms/index.html +++ b/dev/algorithms/extra_algorithms/index.html @@ -72,4 +72,4 @@ julia> Imbalance.checkbalance(y; ref="minority") 1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 10034 (100.0%) -0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (398.3%) source +0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (398.3%) source diff --git a/dev/algorithms/implementation_notes/index.html b/dev/algorithms/implementation_notes/index.html index f448d5e5..e30f53f7 100644 --- a/dev/algorithms/implementation_notes/index.html +++ b/dev/algorithms/implementation_notes/index.html @@ -1,2 +1,2 @@ -Implementation Notes · Imbalance.jl

Generalizing to Multiclass

Papers often propose the resampling algorithm for the case of binary classification only. In many cases, the algorithm only expects a set of points to resample and has nothing to do with the existence of a majority class (e.g., it estimates the distribution of the points and then generates new samples from it), so it can be generalized by simply applying the algorithm to each class. In other cases, there is an interaction with the majority class (e.g., a point is borderline in BorderlineSMOTE1 if the majority, but not all, of its neighbors are from the majority class). In this case, a one-vs-rest scheme is used as proposed in [1]. For instance, a point is now borderline if most, but not all, of its neighbors are from a different class.

Generalizing to Real Ratios

Papers often propose the resampling algorithm using integer ratios. For instance, a ratio of 2 would mean doubling the amount of data in a class, while a ratio of $2.2$ is not allowed or will be rounded. In Imbalance.jl, any appropriate real ratio can be used, and the ratio is relative to the size of the majority or minority class depending on whether the algorithm is oversampling or undersampling. The generalization occurs by randomly choosing points instead of looping over each point. That is, if a $2.2$ ratio corresponds to $227$ examples, then $227$ examples are chosen randomly with replacement and the resampling logic is applied to each. Given an integer ratio $k$, this falls back, on average, to the equivalent of looping over the points $k$ times.

[1] Fernández, A., López, V., Galar, M., Del Jesus, M. J., and Herrera, F. (2013). Analysing the classification of imbalanced data-sets with multiple classes: Binarization techniques and ad-hoc approaches. Knowledge-Based Systems, 42:97–110.

+Implementation Notes · Imbalance.jl

Generalizing to Multiclass

Papers often propose the resampling algorithm for the case of binary classification only. In many cases, the algorithm only expects a set of points to resample and has nothing to do with the existence of a majority class (e.g., it estimates the distribution of the points and then generates new samples from it), so it can be generalized by simply applying the algorithm to each class. In other cases, there is an interaction with the majority class (e.g., a point is borderline in BorderlineSMOTE1 if the majority, but not all, of its neighbors are from the majority class). In this case, a one-vs-rest scheme is used as proposed in [1]. For instance, a point is now borderline if most, but not all, of its neighbors are from a different class. A sketch of this test follows below.
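As an illustrative sketch, the one-vs-rest borderline test can be written as follows (the helper name is hypothetical, not part of the package):

# Hypothetical sketch of the one-vs-rest borderline condition described above:
# a point is borderline when at least half, but not all, of its nearest
# neighbors carry a label different from its own.
function is_borderline(neighbor_labels, label)
    n_diff = count(!=(label), neighbor_labels)
    m = length(neighbor_labels)
    return m / 2 <= n_diff < m    # at least half, but not all, neighbors disagree
end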

Generalizing to Real Ratios

Papers often propose the resampling algorithm using integer ratios. For instance, a ratio of 2 would mean doubling the amount of data in a class, while a ratio of $2.2$ is not allowed or will be rounded. In Imbalance.jl, any appropriate real ratio can be used, and the ratio is relative to the size of the majority or minority class depending on whether the algorithm is oversampling or undersampling. The generalization occurs by randomly choosing points instead of looping over each point. That is, if a $2.2$ ratio corresponds to $227$ examples, then $227$ examples are chosen randomly with replacement and the resampling logic is applied to each. Given an integer ratio $k$, this falls back, on average, to the equivalent of looping over the points $k$ times. A sketch of this computation follows below.
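For illustration, here is a minimal sketch (hypothetical names, not the package's internal API) of how a real ratio translates into a number of randomly chosen points:

using Random

rng = MersenneTwister(42)
# target size is `ratio` times the majority class size; the difference from
# the current class size is how many new points to generate
n_extra(n_class, n_majority, ratio) = max(0, round(Int, ratio * n_majority) - n_class)
# choose that many existing points with replacement; the per-point
# resampling logic is then applied to each chosen point
chosen = rand(rng, 1:20, n_extra(20, 100, 2.2))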

[1] Fernández, A., López, V., Galar, M., Del Jesus, M. J., and Herrera, F. (2013). Analysing the classification of imbalanced data-sets with multiple classes: Binarization techniques and ad-hoc approaches. Knowledge-Based Systems, 42:97–110.

diff --git a/dev/algorithms/mlj_balancing/index.html b/dev/algorithms/mlj_balancing/index.html index c7f9f53f..80c19f73 100644 --- a/dev/algorithms/mlj_balancing/index.html +++ b/dev/algorithms/mlj_balancing/index.html @@ -20,4 +20,4 @@ logistic_model = LogisticClassifier() bagging_model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random.Xoshiro(42))

Now you can fit, predict, cross-validate and fine-tune it like any other probabilistic MLJ model, where X must be a table input (e.g., a dataframe).

mach = machine(bagging_model, X, y)
 fit!(mach)
-pred = predict(mach, X)
+pred = predict(mach, X) diff --git a/dev/algorithms/oversampling_algorithms/index.html b/dev/algorithms/oversampling_algorithms/index.html index 2dc2b7f2..e3ca952b 100644 --- a/dev/algorithms/oversampling_algorithms/index.html +++ b/dev/algorithms/oversampling_algorithms/index.html @@ -373,4 +373,4 @@ oversampler = SMOTENC(y_ind; k=5, ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42) Xyover = Xy |> oversampler # equivalently if TableTransforms is used -Xyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently

Illustration

A full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.

References

[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.

source +Xyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently

Illustration

A full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.

References

[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.

source diff --git a/dev/algorithms/undersampling_algorithms/index.html b/dev/algorithms/undersampling_algorithms/index.html index 8c9d15e6..b2ee204b 100644 --- a/dev/algorithms/undersampling_algorithms/index.html +++ b/dev/algorithms/undersampling_algorithms/index.html @@ -185,4 +185,4 @@ # Initiate TomekUndersampler model undersampler = TomekUndersampler(y_ind; min_ratios=1.0, rng=42) Xy_under = Xy |> undersampler -Xy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently

The reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and revert(undersampler, Xy, cache) is not supported.

Illustration

A full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.

References

[1] Ivan Tomek. Two modifications of CNN. IEEE Trans. Systems, Man and Cybernetics, 6:769–772, 1976.

source +Xy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently

The reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and revert(undersampler, Xy, cache) is not supported.

Illustration

A full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.

References

[1] Ivan Tomek. Two modifications of CNN. IEEE Trans. Systems, Man and Cybernetics, 6:769–772, 1976.

source diff --git a/dev/contributing/index.html b/dev/contributing/index.html index d41a6390..c1fb280b 100644 --- a/dev/contributing/index.html +++ b/dev/contributing/index.html @@ -12,4 +12,4 @@ └── extras.jl # extra functions like generating data or checking balance

The purpose of each file is further documented at the beginning of the file itself. The files are ordered here in the recommended order for checking them.

Any resampling method implemented in the oversampling_methods or undersampling_methods folder takes the following structure:

├── resample_method          # contains implementation and interfaces for a resampling method
 │   ├── interface_mlj.jl     # implements MLJ interface for the method
 │   ├── interface_tables.jl  # implements Tables.jl interface for the method
-│   └── resample_method.jl   # implements the method itself (pure functional interface)

Contribution

Reporting Problems or Seeking Support

Adding New Resampling Methods

Of course, you can ignore the third step if the algorithm you are implementing does not operate in a per-class sense.

🔥 Hot algorithms to add

Adding New Tutorials

+│ └── resample_method.jl # implements the method itself (pure functional interface)

Contribution

Reporting Problems or Seeking Support

Adding New Resampling Methods

Of course, you can ignore the third step if the algorithm you are implementing does not operate in a per-class sense.

🔥 Hot algorithms to add

Adding New Tutorials

diff --git a/dev/examples/Colab/index.html b/dev/examples/Colab/index.html index 40646db5..d341e3c9 100644 --- a/dev/examples/Colab/index.html +++ b/dev/examples/Colab/index.html @@ -9,4 +9,4 @@ rm /tmp/julia.tar.gz fi julia -e 'using Pkg; pkg"add IJulia; precompile;"' -echo 'Done'

Sincere thanks to Julia-on-Colab for making this possible.

+echo 'Done'

Sincere thanks to Julia-on-Colab for making this possible.

diff --git a/dev/examples/cerebral_ensemble/cerebral_ensemble/index.html b/dev/examples/cerebral_ensemble/cerebral_ensemble/index.html index 420ddbf0..f2cd1b60 100644 --- a/dev/examples/cerebral_ensemble/cerebral_ensemble/index.html +++ b/dev/examples/cerebral_ensemble/cerebral_ensemble/index.html @@ -194,4 +194,4 @@ │ BalancedAccuracy( │ predict_mode │ 0.772 │ 0.0146 │ [0.738, 0.769, ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴──────────────┴─────────────┴─────────┴────────────────── - 1 column omitted

Assuming the scores are normally distributed, the 95% confidence interval is 77.2±1.4% for the balanced accuracy.

+ 1 column omitted

Assuming the scores are normally distributed, the 95% confidence interval is 77.2±1.4% for the balanced accuracy.
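As a sketch of how such an interval is obtained (assuming scores holds the per-fold balanced accuracies; the helper is ours, not an MLJ function):

using Statistics

# 95% confidence interval under normality: mean ± 1.96 × standard error
conf_interval(scores) = (mean(scores), 1.96 * std(scores) / sqrt(length(scores)))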

diff --git a/dev/examples/effect_of_k_enn/effect_of_k_enn/index.html b/dev/examples/effect_of_k_enn/effect_of_k_enn/index.html index 3bf82863..2f09e069 100644 --- a/dev/examples/effect_of_k_enn/effect_of_k_enn/index.html +++ b/dev/examples/effect_of_k_enn/effect_of_k_enn/index.html @@ -272,4 +272,4 @@ end plot!(dpi = 150) end -
gif(anim, "./assets/enn-k-animation.gif", fps=1)

enn-gif-hyperparameter

As we can see, the most constraining condition is all. It deletes any point whose label differs from that of any of its k nearest neighbors, which also explains why it is the most sensitive to the hyperparameter k.

+
gif(anim, "./assets/enn-k-animation.gif", fps=1)

enn-gif-hyperparameter

As we can see, the most constraining condition is all. It deletes any point whose label differs from that of any of its k nearest neighbors, which also explains why it is the most sensitive to the hyperparameter k. A sketch of such keep rules follows below.
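To make the conditions concrete, here is a hypothetical sketch of two keep rules for a point with label l and nearest-neighbor labels ls (illustrative names, not the package's internals):

keep_exists(ls, l) = any(==(l), ls)   # keep if at least one neighbor agrees
keep_all(ls, l)    = all(==(l), ls)   # "all": delete if any neighbor disagrees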

diff --git a/dev/examples/effect_of_ratios/effect_of_ratios/index.html b/dev/examples/effect_of_ratios/effect_of_ratios/index.html index 62a92c63..e468fb65 100644 --- a/dev/examples/effect_of_ratios/effect_of_ratios/index.html +++ b/dev/examples/effect_of_ratios/effect_of_ratios/index.html @@ -213,4 +213,4 @@ plot!(dpi = 150) end
gif(anim, "./assets/smote-animation.gif", fps=6)
-println()

Ratios Parameter Effect

Notice how setting ratios greedily can lead to overfitting.

+println()

Ratios Parameter Effect

Notice how setting ratios greedily can lead to overfitting.

diff --git a/dev/examples/effect_of_s/effect_of_s/index.html b/dev/examples/effect_of_s/effect_of_s/index.html index 881e429b..452b6ed0 100644 --- a/dev/examples/effect_of_s/effect_of_s/index.html +++ b/dev/examples/effect_of_s/effect_of_s/index.html @@ -203,4 +203,4 @@ plot!(dpi = 150) end
gif(anim, "./assets/rose-animation.gif", fps=6)
-println()

ROSE Effect of S

As we can see, the larger s is, the more spread out the oversampled points are. This is expected because ROSE oversamples by sampling from the distribution that corresponds to placing Gaussians on the existing points, and s is a hyperparameter proportional to the bandwidth of the Gaussians. When s=0, the only points that can be generated lie on top of existing ones; i.e., ROSE becomes equivalent to random oversampling.

The decision boundary is unstable mainly because we used a small number of epochs with the perceptron to generate this animation; even so, it took plenty of time.

+println()

ROSE Effect of S

As we can see, the larger s is, the more spread out the oversampled points are. This is expected because ROSE oversamples by sampling from the distribution that corresponds to placing Gaussians on the existing points, and s is a hyperparameter proportional to the bandwidth of the Gaussians. When s=0, the only points that can be generated lie on top of existing ones; i.e., ROSE becomes equivalent to random oversampling. A sketch of this idea follows below.
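The following is a minimal sketch of this idea (not the package's implementation):

using Random, Statistics

# Pick an existing point of the class, then perturb it with Gaussian noise
# whose scale is proportional to s; s = 0 reduces to random oversampling.
function rose_point(X, s; rng = MersenneTwister(42))
    x = X[rand(rng, 1:size(X, 1)), :]    # a random existing point (row)
    return x .+ s .* vec(std(X; dims = 1)) .* randn(rng, size(X, 2))
end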

The decision boundary is unstable mainly because we used a small number of epochs with the perceptron to generate this animation; even so, it took plenty of time.

diff --git a/dev/examples/fraud_detection/fraud_detection/index.html b/dev/examples/fraud_detection/fraud_detection/index.html index 95290125..d616fefc 100644 --- a/dev/examples/fraud_detection/fraud_detection/index.html +++ b/dev/examples/fraud_detection/fraud_detection/index.html @@ -144,4 +144,4 @@ │ BalancedAccuracy( │ predict_mode │ 0.908 │ 0.00932 │ [0.903, 0.898, ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴──────────────┴─────────────┴─────────┴────────────────── - 1 column omitted

Assuming normally distributed scores, the 95% confidence interval was 90.8±0.9% and after resampling it has become 93±0.7%, which corresponds to a small improvement in accuracy.

+ 1 column omitted

Assuming normally distributed scores, the 95% confidence interval was 90.8±0.9% and after resampling it has become 93±0.7%, which corresponds to a small improvement in accuracy.

diff --git a/dev/examples/index.html b/dev/examples/index.html index 7d3e3f8d..b7538644 100644 --- a/dev/examples/index.html +++ b/dev/examples/index.html @@ -88,4 +88,4 @@ - + diff --git a/dev/examples/smote_churn_dataset/smote_churn_dataset/index.html b/dev/examples/smote_churn_dataset/smote_churn_dataset/index.html index 7918ce94..f4cc028e 100644 --- a/dev/examples/smote_churn_dataset/smote_churn_dataset/index.html +++ b/dev/examples/smote_churn_dataset/smote_churn_dataset/index.html @@ -163,4 +163,4 @@ │ BalancedAccuracy( │ predict_mode │ 0.552 │ 0.0145 │ [0.549, 0.563, ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴──────────────┴─────────────┴─────────┴────────────────── - 1 column omitted

The improvement is about 5.2% after cross-validation. If we further assume the scores to be normally distributed, then the 95% confidence interval is a 5.2±1.45% improvement. Let's see if this gets any better when we instead use SMOTE-NC in a later example.

+ 1 column omitted

The improvement is about 5.2% after cross-validation. If we further assume the scores to be normally distributed, then the 95% confidence interval is a 5.2±1.45% improvement. Let's see if this gets any better when we instead use SMOTE-NC in a later example.

diff --git a/dev/examples/smoten_mushroom/smoten_mushroom/index.html b/dev/examples/smoten_mushroom/smoten_mushroom/index.html index 5658c026..a631ccc6 100644 --- a/dev/examples/smoten_mushroom/smoten_mushroom/index.html +++ b/dev/examples/smoten_mushroom/smoten_mushroom/index.html @@ -253,4 +253,4 @@ │ BalancedAccuracy( │ predict │ 0.4 │ 0.00483 │ [0.398, 0.405, 0.3 ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴───────────┴─────────────┴─────────┴───────────────────── - 1 column omitted

Fair enough. After oversampling, the interval under the same assumptions is 40±0.5%. This agrees with our earlier observations using simple point estimates: oversampling here delivers approximately an 18% improvement in balanced accuracy.

+ 1 column omitted

Fair enough. After oversampling, the interval under the same assumptions is 40±0.5%. This agrees with our earlier observations using simple point estimates: oversampling here delivers approximately an 18% improvement in balanced accuracy.

diff --git a/dev/examples/smotenc_churn_dataset/smotenc_churn_dataset/index.html b/dev/examples/smotenc_churn_dataset/smotenc_churn_dataset/index.html index 5c252769..ec895683 100644 --- a/dev/examples/smotenc_churn_dataset/smotenc_churn_dataset/index.html +++ b/dev/examples/smotenc_churn_dataset/smotenc_churn_dataset/index.html @@ -190,4 +190,4 @@ │ BalancedAccuracy( │ predict_mode │ 0.677 │ 0.0124 │ [0.678, 0.688, ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴──────────────┴─────────────┴─────────┴────────────────── - 1 column omitted

Fair enough. After oversampling, the interval under the same assumptions is 67.7±1.2%, which is still a meaningful improvement over the 56.5±0.62% that we had prior to oversampling or the 55.2±1.5% that we had with logistic regression and SMOTE in an earlier example.

+ 1 column omitted

Fair enough. After oversampling, the interval under the same assumptions is 67.7±1.2%, which is still a meaningful improvement over the 56.5±0.62% that we had prior to oversampling or the 55.2±1.5% that we had with logistic regression and SMOTE in an earlier example.

diff --git a/dev/examples/walkthrough/index.html b/dev/examples/walkthrough/index.html index 979b400c..3754fa78 100644 --- a/dev/examples/walkthrough/index.html +++ b/dev/examples/walkthrough/index.html @@ -239,4 +239,4 @@ │ BalancedAccuracy( │ predict_mode │ 0.7 │ 0.0717 │ [0.7, 0.536, 0. ⋯ │ adjusted = false) │ │ │ │ ⋯ └─────────────────────┴──────────────┴─────────────┴─────────┴────────────────── - 1 column omitted

This results in an interval of 70±7.2%, which can be viewed as a reasonable improvement over 62.1±9.13%. The uncertainty in the intervals can be explained by the fact that the dataset is small and has many classes.

+ 1 column omitted

This results in an interval of 70±7.2%, which can be viewed as a reasonable improvement over 62.1±9.13%. The uncertainty in the intervals can be explained by the fact that the dataset is small and has many classes.

diff --git a/dev/index.html b/dev/index.html index 149fb51a..07bfe1f0 100644 --- a/dev/index.html +++ b/dev/index.html @@ -8,7 +8,7 @@ X, y = generate_imbalanced_data(num_rows, num_continuous_feats; class_probs, rng=42)

In the following code blocks, it will be assumed that X and y are readily available.

Standard API

All methods by default support a pure functional interface.

using Imbalance
 
 # Apply SMOTE to oversample the classes
-Xover, yover = smote(X, y; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)

In the following code blocks, it will be assumed that X and y are readily available.

MLJ Interface

All methods support the MLJ interface: instead of calling the method directly, one instantiates a model for the method (optionally passing the keyword parameters found in the functional interface), wraps the model in a machine, and then calls transform on the machine and data.

using MLJ
+Xover, yover = smote(X, y; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)

MLJ Interface

All methods support the MLJ interface: instead of calling the method directly, one instantiates a model for the method (optionally passing the keyword parameters found in the functional interface), wraps the model in a machine, and then calls transform on the machine and data.

using MLJ
 
 # Load the model
 SMOTE = @load SMOTE pkg=Imbalance
@@ -53,4 +53,4 @@
 oversampler = SMOTE(y_ind; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)
 Xyover = Xy |> oversampler       # can chain with other table transforms                  
 # equivalently if TableTransforms is used
-Xyover, cache = TableTransforms.apply(oversampler, Xy)    # equivalently

The reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversampler, Xy) and revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.

Notice that because the MLJ and TableTransforms interfaces use the same model names, you will have to qualify the source of the model if both are used in the same file (e.g., Imbalance.TableTransforms.SMOTE for the example above).

Features

Rationale

Most, if not all, machine learning algorithms can be viewed as a form of empirical risk minimization, where the objective is to find the parameters $\theta$ that, for some loss function $L$, minimize

\[\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} L(f_{\theta}(x_i), y_i)\]

The underlying assumption is that minimizing this empirical risk corresponds to approximately minimizing the true risk, which considers all examples in the population. This would imply that $f_\theta$ is approximately the true target function $f$ that we seek to model.

In a multi-class setting with $K$ classes, one can write

\[\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \left( \sum_{i \in C_1} L(f_{\theta}(x_i), y_i) + \sum_{i \in C_2} L(f_{\theta}(x_i), y_i) + \ldots + \sum_{i \in C_K} L(f_{\theta}(x_i), y_i) \right)\]

Class imbalance occurs when some classes have far fewer examples than others. In this case, the terms corresponding to the smaller classes contribute minimally to the sum, which makes it possible for any learning algorithm to find an approximate solution to minimizing the empirical risk that mostly only minimizes the significant sums. This yields a hypothesis $f_\theta$ that may be very different from the true target $f$ with respect to the minority classes, which may be the most important ones for the application in question.

One obvious possible remedy is to weight the smaller sums so that a learning algorithm more easily avoids approximate solutions that exploit their insignificance; this can be seen to be equivalent to repeating the observations in the minority classes. That can be achieved by naive random oversampling, which is offered by this package along with more advanced methods that operate by generating synthetic data (oversampling) or deleting existing observations (undersampling). You can read more about the class imbalance problem and learn about various algorithms implemented in this package by reading this series of articles on Medium.

To our knowledge, there are no existing maintained Julia packages that implement resampling algorithms for multi-class classification problems or that handle both nominal and continuous features. This has served as a primary motivation for the creation of this package.

+Xyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently

The reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversampler, Xy) and revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.
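For instance, a hedged round-trip sketch building on the snippet above:

Xyover, cache = TableTransforms.apply(oversampler, Xy)        # oversample
Xy_back = TableTransforms.revert(oversampler, Xyover, cache)  # drop the synthetic rows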

Notice that because the MLJ and TableTransforms interfaces use the same model names, you will have to qualify the source of the model if both are used in the same file (e.g., Imbalance.TableTransforms.SMOTE for the example above).

Features

Rationale

Most, if not all, machine learning algorithms can be viewed as a form of empirical risk minimization, where the objective is to find the parameters $\theta$ that, for some loss function $L$, minimize

\[\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} L(f_{\theta}(x_i), y_i)\]

The underlying assumption is that minimizing this empirical risk corresponds to approximately minimizing the true risk, which considers all examples in the population. This would imply that $f_\theta$ is approximately the true target function $f$ that we seek to model.

In a multi-class setting with $K$ classes, one can write

\[\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \left( \sum_{i \in C_1} L(f_{\theta}(x_i), y_i) + \sum_{i \in C_2} L(f_{\theta}(x_i), y_i) + \ldots + \sum_{i \in C_K} L(f_{\theta}(x_i), y_i) \right)\]

Class imbalance occurs when some classes have far fewer examples than others. In this case, the terms corresponding to the smaller classes contribute minimally to the sum, which makes it possible for any learning algorithm to find an approximate solution to minimizing the empirical risk that mostly only minimizes the significant sums. This yields a hypothesis $f_\theta$ that may be very different from the true target $f$ with respect to the minority classes, which may be the most important ones for the application in question.

One obvious possible remedy is to weight the smaller sums so that a learning algorithm more easily avoids approximate solutions that exploit their insignificance; this can be seen to be equivalent to repeating the observations in the minority classes, as sketched below. That can be achieved by naive random oversampling, which is offered by this package along with more advanced methods that operate by generating synthetic data (oversampling) or deleting existing observations (undersampling). You can read more about the class imbalance problem and learn about various algorithms implemented in this package by reading this series of articles on Medium.
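In the notation above, weighting the class sums (a sketch; the particular weights are ours for illustration) amounts to solving

\[\hat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{k=1}^{K} w_k \sum_{i \in C_k} L(f_{\theta}(x_i), y_i), \qquad w_k \propto \frac{1}{N_k},\]

so that each class contributes comparably to the objective; repeating each observation of class $k$ about $w_k$ times produces the same sum, which is what naive random oversampling approximates.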

To our knowledge, there are no existing maintained Julia packages that implement resampling algorithms for multi-class classification problems or that handle both nominal and continuous features. This has served as a primary motivation for the creation of this package.

diff --git a/dev/search_index.js b/dev/search_index.js index b477ba51..3899b2f0 100644 --- a/dev/search_index.js +++ b/dev/search_index.js @@ -1,3 +1,3 @@ var documenterSearchIndex = {"docs": -[{"location":"algorithms/extra_algorithms/#Extras","page":"Extras","title":"Extras","text":"","category":"section"},{"location":"algorithms/extra_algorithms/#Generate-Imbalanced-Data","page":"Extras","title":"Generate Imbalanced Data","text":"","category":"section"},{"location":"algorithms/extra_algorithms/","page":"Extras","title":"Extras","text":"generate_imbalanced_data","category":"page"},{"location":"algorithms/extra_algorithms/#Imbalance.generate_imbalanced_data","page":"Extras","title":"Imbalance.generate_imbalanced_data","text":"generate_imbalanced_data(\n num_rows, num_continuous_feats;\n means=nothing, min_sep=1.0, stds=nothing,\n num_vals_per_category = [],\n class_probs = [0.8, 0.2],\n type= \"ColTable\", insert_y= nothing,\n rng= default_rng(),\n)\n\nGenerate num_rows observations with target y respecting given probabilities of each class. Supports generating continuous features with a specific mean and variance and categorical features given the number of levels in each variable.\n\nArguments\n\nnum_rows::Integer: Number of observations to generate\nnum_continuous_feats::Integer: Number of continuous features to generate\nmeans::AbstractVector=nothing: A vector of means for each continuous feature (must be as long as num_continuous_feats). If nothing, then will be set randomly\nmin_sep::AbstractFloat=1.0: Minimum distance between any two randomly chosen means. Will have no effect if the means are given.\nstds::AbstractVector=nothing: A vector of standard deviations for each continuous feature (must be as long as num_continuous_feats). If nothing, then will be set randomly\nnum_vals_per_category::AbstractVector=[]: A vector of the number of levels of each extra categorical feature. the number of categorical features is inferred from this.\nclass_probs::AbstractVector{<:AbstractFloat}=[0.8, 0.2]: A vector of probabilities of each class. The number of classes is inferred from this vector.\ntype::AbstractString=\"ColTable\": Can be \"Matrix\" or \"ColTable\". In the latter case, a named-tuple of vectors is returned.\ninsert_y::Integer=nothing: If not nothing, insert the class labels column at the given index in the table\nrng::Union{AbstractRNG, Integer}=default_rng(): Random number generator. If integer then used as seed in Random.Xoshiro(seed) if the Julia VERSION supports it. Otherwise, uses Random.MersenneTwister(seed).\n\nReturns\n\nX:: A column table or matrix with generated imbalanced data with num_rows rows and num_continuous_feats + length(num_vals_per_category) columns. 
If insert_y is specified as in integer then y is also inserted at the specified index as an extra column.\ny::CategoricalArray: An abstract vector of class labels with labels 0, 1, 2, ..., k-1 where k=length(class_probs)\n\nExample\n\nusing Imbalance\nusing Plots\n\nnum_rows = 500\nnum_features = 2\n# generating continuous features given mean and std\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n\tmeans = [1.0, 4.0, [7.0 9.0]],\n\tstds = [1.0, [0.5 0.8], 2.0],\n\tclass_probs=[0.5, 0.2, 0.3],\n\ttype=\"Matrix\",\n\trng = 42,\n)\n\np = plot()\n[scatter!(p, X[:, 1][y.==yi], X[:, 2][y.==yi], label = \"$y=yi$\") for yi in unique(y)]\n\njulia> plot(p)\n\n(Image: generated data)\n\n# generating continuous features with random mean and std\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n min_sep=0.3, \n\tclass_probs=[0.5, 0.2, 0.3],\n\ttype=\"Matrix\",\n\trng = 33,\n)\n\np = plot()\n[scatter!(p, X[:, 1][y.==yi], X[:, 2][y.==yi], label = \"$y=yi$\") for yi in unique(y)]\n\njulia> plot(p)\n\n(Image: generated data)\n\nnum_rows = 500\nnum_features = 2\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n num_vals_per_category = [3, 5, 2],\n\tclass_probs=[0.9, 0.1],\n\tinsert_y=4,\n\ttype=\"ColTable\",\n\trng = 33,\n)\n\njulia> X\n(Column1 = [0.883, 0.9, 0.577 … 0.887,],\n Column2 = [0.578, 0.718, 0.378 … 0.573,],\n Column3 = [2.0, 2.0, 3.0, … 2.0,],\n Column4 = [0.0, 0.0, 0.0, … 0.0,],\n Column5 = [2.0, 3.0, 4.0, … 4.0,],\n Column6 = [1.0, 1.0, 2.0, … 1.0,],)\n\n\n\n\n\n","category":"function"},{"location":"algorithms/extra_algorithms/#Check-Balance-of-Data","page":"Extras","title":"Check Balance of Data","text":"","category":"section"},{"location":"algorithms/extra_algorithms/","page":"Extras","title":"Extras","text":"checkbalance","category":"page"},{"location":"algorithms/extra_algorithms/#Imbalance.checkbalance","page":"Extras","title":"Imbalance.checkbalance","text":"checkbalance(y; reference=\"majority\")\n\nA visual version of StatsBase.countmap that returns nothing and prints how many observations in the dataset belong to each class and their percentage relative to the size of majority or minority class.\n\nArguments\n\ny::AbstractVector: A vector of categorical values to test for imbalance\nreference=\"majority\": Either \"majority\" or \"minority\" and decides whether the percentage should be relative to the size of majority or minority class.\n\nExample\n\nnum_rows = 50000\nnum_features = 2\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n\tclass_probs=[0.8, 0.2],\n\ttype=\"Matrix\",\n\trng = 42,\n)\n\njulia> Imbalance.checkbalance(y; ref=\"majority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 10034 (25.1%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (100.0%) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 10034 (100.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (398.3%) \n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Oversampling-Algorithms","page":"Oversampling","title":"Oversampling Algorithms","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"The following table portrays the supported oversampling algorithms, whether the mechanism repeats or generates data and the supported types of data.","category":"page"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"Oversampling Method Mechanism Supported Data 
Types\nRandom Oversampler Repeat existing data Continuous and/or nominal\nRandom Walk Oversampler Generate synthetic data Continuous and/or nominal\nROSE Generate synthetic data Continuous\nSMOTE Generate synthetic data Continuous\nBorderline SMOTE1 Generate synthetic data Continuous\nSMOTE-N Generate synthetic data Nominal\nSMOTE-NC Generate synthetic data Continuous and nominal","category":"page"},{"location":"algorithms/oversampling_algorithms/#Random-Oversampler","page":"Oversampling","title":"Random Oversampler","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"random_oversample","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.random_oversample","page":"Oversampling","title":"Imbalance.random_oversample","text":"random_oversample(\n X, y; \n ratios=1.0, rng=default_rng(), \n try_preserve_type=true\n)\n\nDescription\n\nNaively oversample a dataset by randomly repeating existing observations with replacement.\n\nPositional Arguments\n\nX: A matrix of real numbers or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. 
depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply random oversampling\nXover, yover = random_oversample(X, y; ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomOversampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nRandomOversampler = @load RandomOversampler pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = RandomOversampler(ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate Random Oversampler model\noversampler = RandomOversampler(y_ind; ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Random-Walk-Oversampler","page":"Oversampling","title":"Random Walk Oversampler","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"random_walk_oversample","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.random_walk_oversample","page":"Oversampling","title":"Imbalance.random_walk_oversample","text":"random_walk_oversample(\n\tX, y, cat_inds;\n\tratios=1.0, rng=default_rng(),\n\ttry_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using random walk oversampling as presented in [1]. 
\n\nPositional Arguments\n\nX: A matrix of floats or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\ncat_inds::AbstractVector{<:Int}: A vector of the indices of the nominal features. Supplied only if X is a matrix. Otherwise, they are inferred from the table's scitypes.\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 3\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \n\t\t\t\t\t\t\t\t \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Continuous, Continuous, Continuous, Continuous, Continuous)\n# coerce nominal columns to a finite scitype (multiclass or ordered factor)\nX = coerce(X, :Column4=>Multiclass, :Column5=>Multiclass)\n\n# apply random walk oversampling\nXover, yover = random_walk_oversample(X, y; \n ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomWalkOversampling model and pass the \tpositional arguments (excluding cat_inds) to the transform method. 
\n\nusing MLJ\nRandomWalkOversampler = @load RandomWalkOversampler pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = RandomWalkOversampler(ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser. Note that only Table input is supported by the MLJ interface for this method.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind \tmust be specified to the constructor to specify which column y is followed by other keyword arguments. \tOnly Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 3\ny_ind = 2\n\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite or continuous scitypes \nXy = coerce(Xy, :Column2=>Multiclass, :Column5=>Multiclass, :Column6=>Multiclass)\n\n# Initiate Random Walk Oversampler model\noversampler = RandomWalkOversampler(y_ind;\n ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Zhang, H., & Li, M. (2014). RWO-Sampling: A random walk over-sampling approach to imbalanced data classification. Information Fusion, 25, 4-20.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#ROSE","page":"Oversampling","title":"ROSE","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"rose","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.rose","page":"Oversampling","title":"Imbalance.rose","text":"rose(\n X, y; \n s=1.0, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using ROSE (Random Oversampling Examples) algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\ns::float=1.0: A parameter that proportionally controls the bandwidth of the Gaussian kernel\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. 
Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply ROSE\nXover, yover = rose(X, y; s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ROSE model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nROSE = @load ROSE pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = ROSE(s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate ROSE model\noversampler = ROSE(y_ind; s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] G Menardi, N. 
Torelli, “Training and assessing classification rules with imbalanced data,” Data Mining and Knowledge Discovery, 28(1), pp.92-122, 2014.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE","page":"Oversampling","title":"SMOTE","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smote","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smote","page":"Oversampling","title":"Imbalance.smote","text":"smote(\n X, y;\n k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using SMOTE (Synthetic Minority Oversampling Techniques) algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply SMOTE\nXover, yover = smote(X, y; k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the SMOTE model and pass the positional arguments X, y to the transform method. 
\n\nusing MLJ\nSMOTE = @load SMOTE pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTE(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate SMOTE model\noversampler = SMOTE(y_ind; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Borderline-SMOTE1","page":"Oversampling","title":"Borderline SMOTE1","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"borderline_smote1","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.borderline_smote1","page":"Oversampling","title":"Imbalance.borderline_smote1","text":"borderline_smote1(\n X, y;\n m=5, k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true, verbosity=1\n)\n\nDescription\n\nOversamples a dataset using borderline SMOTE1 algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nm::Integer=5: The number of neighbors to consider while checking the BorderlineSMOTE1 condition. Should be within the range 0 < m < N where N is the number of observations in the data. It will be automatically set to N-1 if N ≤ m.\nk::Integer=5: Number of nearest neighbors to consider in the SMOTE part of the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class. It will be automatically set to l-1 for any class with l points where l ≤ k.\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. 
By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nverbosity::Integer=1: Whenever higher than 0 info regarding the points that will participate in oversampling is logged.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 1000, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n stds=[0.1 0.1 0.1], min_sep=0.01, class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 200 (40.8%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 310 (63.3%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 490 (100.0%) \n\n# apply BorderlineSMOTE1\nXover, yover = borderline_smote1(X, y; m = 3, \n k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 392 (80.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 441 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 490 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the BorderlineSMOTE1 model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nBorderlineSMOTE1 = @load BorderlineSMOTE1 pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = BorderlineSMOTE1(m=3, k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. 
Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 1000\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], min_sep=0.01, insert_y=y_ind, rng=42)\n\n# Instantiate the BorderlineSMOTE1 oversampler model\noversampler = BorderlineSMOTE1(y_ind; m=3, k=5, \n ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) \n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversampler, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.
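\n\nFor instance, continuing with the oversampler and Xy defined above, the round trip below sketches this behavior (a minimal sketch following the TableTransforms.jl conventions):\n\nXyover, cache = TableTransforms.apply(oversampler, Xy)\n# undo the oversampling by removing the synthetic observations\nXy_restored = TableTransforms.revert(oversampler, Xyover, cache)\n# reapply simply falls back to apply\nXyover2 = TableTransforms.reapply(oversampler, Xy, cache)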
\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.\n\nReferences\n\n[1] Han, H., Wang, W.-Y., & Mao, B.-H. (2005). Borderline-SMOTE: A new over-sampling method in imbalanced data sets learning. In D.S. Huang, X.-P. Zhang, & G.-B. Huang (Eds.), Advances in Intelligent Computing (pp. 878-887). Springer. \n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE-N","page":"Oversampling","title":"SMOTE-N","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smoten","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smoten","page":"Oversampling","title":"Imbalance.smoten","text":"smoten(\n X, y;\n k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using the SMOTE-N (Synthetic Minority Oversampling Technique-Nominal) algorithm to correct for class imbalance as presented in [1]. This is a variant of SMOTE to deal with datasets where all features are nominal.\n\nPositional Arguments\n\nX: A matrix of integers or a table with element scitypes that subtype Finite. That is, for table inputs each column should have either OrderedFactor or Multiclass as the element scitype.\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling, depending on whether the input X is a matrix or table, respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 0\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Count, Count)\n\n# coerce to a finite scitype (multiclass or ordered factor)\nX = coerce(X, autotype(X, :few_to_finite))\n\n# apply SMOTEN\nXover, yover = smoten(X, y; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while instantiating the SMOTEN model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nSMOTEN = @load SMOTEN pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTEN(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be passed to the constructor to specify which column is y, followed by the other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 0\ny_ind = 2\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite scitypes \nXy = coerce(Xy, :Column1=>Multiclass, :Column2=>Multiclass, :Column3=>Multiclass)\n\n# Instantiate the SMOTEN model\noversampler = SMOTEN(y_ind; k=5, ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy)\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversampler, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. 
Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE-NC","page":"Oversampling","title":"SMOTE-NC","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smotenc","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smotenc","page":"Oversampling","title":"Imbalance.smotenc","text":"smotenc(\n X, y, cat_inds;\n k=5, ratios=1.0, knn_tree=\"Brute\", rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using the SMOTE-NC (Synthetic Minority Oversampling Technique-Nominal Continuous) algorithm to correct for class imbalance as presented in [1]. This is a variant of SMOTE to deal with datasets with both nominal and continuous features. \n\nwarning: SMOTE-NC Assumes Continuous Features Exist\nSMOTE-NC will not work if the dataset is purely nominal. In that case, refer to SMOTE-N instead. Meanwhile, if the dataset is purely continuous then it's equivalent to the standard SMOTE.\n\nPositional Arguments\n\nX: A matrix of floats or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\ncat_inds::AbstractVector{<:Int}: A vector of the indices of the nominal features. Supplied only if X is a matrix. Otherwise, they are inferred from the table's scitypes.\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nknn_tree: Decides the tree used in KNN computations. Either \"Brute\" or \"Ball\". BallTree can be much faster but may lead to inaccurate results.\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling, 
depending on whether the input X is a matrix or table, respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 3\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Continuous, Continuous, Continuous, Continuous, Continuous)\n# coerce nominal columns to a finite scitype (multiclass or ordered factor)\nX = coerce(X, :Column4=>Multiclass, :Column5=>Multiclass)\n\n# apply SMOTE-NC\nXover, yover = smotenc(X, y; k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while instantiating the SMOTENC model and pass the positional arguments (excluding cat_inds) to the transform method. \n\nusing MLJ\nSMOTENC = @load SMOTENC pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTENC(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser. Note that only Table input is supported by the MLJ interface for this method.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be passed to the constructor to specify which column is y, followed by the other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 3\ny_ind = 2\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite or continuous scitypes \nXy = coerce(Xy, :Column2=>Multiclass, :Column5=>Multiclass, :Column6=>Multiclass)\n\n# Instantiate the SMOTENC model\noversampler = SMOTENC(y_ind; k=5, ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy)\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section, which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. 
Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"examples/walkthrough/#Introduction","page":"Introduction","title":"Introduction","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In this section of the docs, we will walk you through some examples to demonstrate how you can use Imbalance.jl in your machine learning project. Although we focus on examples, you can learn more about how specific algorithms work by reading this series of blogposts on Medium.","category":"page"},{"location":"examples/walkthrough/#Prerequisites","page":"Introduction","title":"Prerequisites","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In the examples that follow, we will assume familiarity with the CSV, DataFrames, ScientificTypes and MLJ packages, all of which come with excellent documentation. This example is devoted to ensuring and reinforcing your familiarity with these packages. You can try all the examples in the docs in your browser using Google Colab; you can read more about that in the last section.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \n \"Imbalance\", \"MLJBalancing\", \"ScientificTypes\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing ScientificTypes\nusing HTTP: download","category":"page"},{"location":"examples/walkthrough/#Loading-Data","page":"Introduction","title":"Loading Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In this example, we will consider the BMI dataset found on Kaggle where the objective is to predict the BMI index of individuals given their gender, weight and height. 
","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/bmi.csv\", \"./\")\ndf = CSV.read(\"./bmi.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌─────────┬────────┬────────┬───────┐\n│ Gender │ Height │ Weight │ Index │\n│ String7 │ Int64 │ Int64 │ Int64 │\n│ Textual │ Count │ Count │ Count │\n├─────────┼────────┼────────┼───────┤\n│ Male │ 174 │ 96 │ 4 │\n│ Male │ 189 │ 87 │ 2 │\n│ Female │ 185 │ 110 │ 4 │\n│ Female │ 195 │ 104 │ 3 │\n│ Male │ 149 │ 61 │ 3 │\n└─────────┴────────┴────────┴───────┘\n\n\n┌ Warning: Reading one byte at a time from HTTP.Stream is inefficient.\n│ Use: io = BufferedInputStream(http::HTTP.Stream) instead.\n│ See: https://github.com/BioJulia/BufferedStreams.jl\n└ @ HTTP.Streams /Users/essam/.julia/packages/HTTP/SN7VW/src/Streams.jl:240\n┌ Info: Downloading\n│ source = https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/bmi.csv\n│ dest = ./bmi.csv\n│ progress = NaN\n│ time_taken = 0.0 s\n│ time_remaining = NaN s\n│ average_speed = 7.933 MiB/s\n│ downloaded = 8.123 KiB\n│ remaining = ∞ B\n│ total = ∞ B\n└ @ HTTP /Users/essam/.julia/packages/HTTP/SN7VW/src/download.jl:132","category":"page"},{"location":"examples/walkthrough/#Coercing-Data","page":"Introduction","title":"Coercing Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. Among the many types defined by the package, we are interested in Multiclass, OrderedFactor which fall under the Finite abstract type and Continuous and Count which fall under the Infinite abstract type.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"One motivation for this package is that it's not generally obvious whether numerical data in an input table is of continuous type or categorical type given that numbers can describe both. 
Meanwhile, it's problematic if a model treats numerical data as, say, Continuous or Count when it's in reality nominal (i.e., Multiclass) or ordinal (i.e., OrderedFactor).","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We can use schema(df) to see how each feature is currently going to be interpreted by the resampling algorithms: ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌────────┬──────────┬─────────┐\n│ names │ scitypes │ types │\n├────────┼──────────┼─────────┤\n│ Gender │ Textual │ String7 │\n│ Height │ Count │ Int64 │\n│ Weight │ Count │ Int64 │\n│ Index │ Count │ Int64 │\n└────────┴──────────┴─────────┘","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"To change encodings that are leading to incorrect interpretations (true for all variables in this example), we use the coerce method, as follows:","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"df = coerce(df,\n :Gender => Multiclass,\n :Height => Continuous,\n :Weight => Continuous,\n :Index => OrderedFactor)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌────────┬──────────────────┬───────────────────────────────────┐\n│ names │ scitypes │ types │\n├────────┼──────────────────┼───────────────────────────────────┤\n│ Gender │ Multiclass{2} │ CategoricalValue{String7, UInt32} │\n│ Height │ Continuous │ Float64 │\n│ Weight │ Continuous │ Float64 │\n│ Index │ OrderedFactor{6} │ CategoricalValue{Int64, UInt32} │\n└────────┴──────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/walkthrough/#Unpacking-and-Splitting-Data","page":"Introduction","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y, X = unpack(df, ==(:Index); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌───────────────────────────────────┬────────────┬────────────┐\n│ Gender │ Height │ Weight │\n│ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │\n│ Multiclass{2} │ Continuous │ Continuous │\n├───────────────────────────────────┼────────────┼────────────┤\n│ Female │ 173.0 │ 82.0 │\n│ Female │ 187.0 │ 121.0 │\n│ Male │ 144.0 │ 145.0 │\n│ Male │ 156.0 │ 74.0 │\n│ Male │ 167.0 │ 151.0 │\n└───────────────────────────────────┴────────────┴────────────┘","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. 
stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits, which is more representative of the real world.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(42)\n)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"((399×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 179.0 150.0\n 2 │ Male 141.0 80.0\n 3 │ Male 179.0 152.0\n 4 │ Male 187.0 138.0\n 5 │ Male 148.0 155.0\n 6 │ Female 192.0 101.0\n 7 │ Male 145.0 78.0\n 8 │ Female 162.0 159.0\n ⋮ │ ⋮ ⋮ ⋮\n 393 │ Female 161.0 154.0\n 394 │ Female 172.0 109.0\n 395 │ Female 163.0 159.0\n 396 │ Female 186.0 146.0\n 397 │ Male 194.0 106.0\n 398 │ Female 167.0 153.0\n 399 │ Female 162.0 64.0\n 384 rows omitted, 101×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 157.0 56.0\n 2 │ Male 180.0 75.0\n 3 │ Female 157.0 110.0\n 4 │ Female 182.0 143.0\n 5 │ Male 165.0 104.0\n 6 │ Male 182.0 73.0\n 7 │ Male 165.0 68.0\n 8 │ Male 166.0 107.0\n ⋮ │ ⋮ ⋮ ⋮\n 95 │ Male 163.0 137.0\n 96 │ Female 188.0 99.0\n 97 │ Female 146.0 123.0\n 98 │ Male 186.0 68.0\n 99 │ Female 140.0 76.0\n 100 │ Female 168.0 139.0\n 101 │ Male 180.0 149.0\n 86 rows omitted), (CategoricalArrays.CategoricalValue{Int64, UInt32}[5, 5, 5, 4, 5, 3, 4, 5, 5, 5 … 5, 4, 4, 5, 4, 5, 5, 3, 5, 2], CategoricalArrays.CategoricalValue{Int64, UInt32}[2, 2, 5, 5, 4, 2, 2, 4, 3, 3 … 2, 0, 0, 5, 3, 5, 2, 4, 5, 5]))","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations, then train-test contamination has occurred; in deployment, novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/walkthrough/#Oversampling","page":"Introduction","title":"Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Before deciding to oversample, let's see how severe the imbalance problem is, if it exists. Ideally, you should also check whether the classification model is robust to this problem.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0: ▇▇▇ 13 (6.6%) \n1: ▇▇▇▇▇▇ 22 (11.1%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 68 (34.3%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 69 (34.8%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 130 (65.7%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 198 (100.0%)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Looks like we have a class imbalance problem. 
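Each ratio we pass to an oversampler is interpreted relative to the size of the majority class, which has 158 observations in our training split. As a quick sanity check of what a given ratio will therefore mean, here is the plain arithmetic (target is a hypothetical helper for illustration only; the package may round slightly differently):\n\nmajority = 158                       # size of the majority class in y_train\ntarget(r) = round(Int, r * majority) # hypothetical helper, not a package function\ntarget(0.3)                          # 47\ntarget(0.5)                          # 79\n\n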
Let's set the desired ratios so that the first two classes are 30% of the majority class, the second two are 50% of the majority class, and the rest are left as is (not included in the dictionary)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"ratios = Dict(0=>0.3, 1=>0.3, 2=>0.5, 3=>0.5) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Dict{Int64, Float64} with 4 entries:\n 0 => 0.3\n 2 => 0.5\n 3 => 0.5\n 1 => 0.3","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Let's use random oversampling to oversample the data. This particular model does not care about the scientific types of the data. It takes X and y as positional arguments, while ratios and rng are the main keyword arguments","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Xover, yover = random_oversample(X_train, y_train; ratios, rng=42) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"(514×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 179.0 150.0\n 2 │ Male 141.0 80.0\n 3 │ Male 179.0 152.0\n 4 │ Male 187.0 138.0\n 5 │ Male 148.0 155.0\n 6 │ Female 192.0 101.0\n 7 │ Male 145.0 78.0\n 8 │ Female 162.0 159.0\n ⋮ │ ⋮ ⋮ ⋮\n 508 │ Female 196.0 50.0\n 509 │ Male 193.0 54.0\n 510 │ Male 182.0 50.0\n 511 │ Male 190.0 50.0\n 512 │ Male 190.0 50.0\n 513 │ Male 198.0 50.0\n 514 │ Male 198.0 50.0\n 499 rows omitted, CategoricalArrays.CategoricalValue{Int64, UInt32}[5, 5, 5, 4, 5, 3, 4, 5, 5, 5 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"checkbalance(yover)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 47 (29.7%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 47 (29.7%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 79 (50.0%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 79 (50.0%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 104 (65.8%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 158 (100.0%)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"This indeed aligns with the desired ratios we set earlier.","category":"page"},{"location":"examples/walkthrough/#Training-the-Model","page":"Introduction","title":"Training the Model","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Because we have the scientific types set up, we can easily check which models will be able to train on our data. 
This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"5-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Let's go for a decision tree from BetaML","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML verbosity=0\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier(max_depth=5, rng=Random.Xoshiro(42))\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 5, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 5, …)\n args: \n 1:\tSource @027 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @092 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/#After-Oversampling","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 5, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 5, …)\n args: \n 1:\tSource @592 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @711 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/#Evaluating-the-Model","page":"Introduction","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"To evaluate the model, we will use the balanced accuracy metric, which accounts for all classes equally. For instance, if we have two classes and we correctly classify 100% of the examples in the first and 50% of the examples in the second, then the balanced accuracy is (100+50)/2 = 75%. This holds regardless of how big or small each class is.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"predict_mode will return a vector of predictions given X_test and the fitted machine. It differs from predict in that it does not return the probabilities the model assigns to each class; instead, it returns the classes with the maximum probability, i.e., the modes.","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling-2","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0.62","category":"page"},{"location":"examples/walkthrough/#After-Oversampling-2","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y_pred_over = predict_mode(mach_over, X_test)\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0.75","category":"page"},{"location":"examples/walkthrough/#Evaluating-the-Model-Revisited","page":"Introduction","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy, observing a 13 percentage point improvement. A more reliable evaluation would use cross-validation to combine many point estimates into a more precise one (their average). 
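In MLJ, that average is what evaluate! reports in the measurement column; a minimal sketch of the relationship, assuming e is the PerformanceEvaluation object returned by evaluate! below:\n\nusing Statistics\nper_fold = e.per_fold[1]  # fold-wise balanced accuracies (one vector per measure)\nmean(per_fold)            # the aggregated point estimate reported as measurement\n\n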
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation implies a narrower confidence interval at the same confidence level.","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling-3","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.621 │ 0.0913 │ [0.593, 0.473, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Under the normality assumption, the 95% confidence interval is 62.1±9.13%, which is quite wide. Let's see how it looks after oversampling.","category":"page"},{"location":"examples/walkthrough/#After-Oversampling-3","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"At first glance, this seems nontrivial, since resampling would have to be performed before training the model on each fold during cross-validation. Thankfully, MLJBalancing helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model together with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"To do this, we construct the resampling model via its MLJ interface, then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.RandomOversampler(ratios=ratios, rng=42)\nmodel = DecisionTreeClassifier(max_depth=5, rng=Random.Xoshiro(42))\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = DecisionTreeClassifier(max_depth = 5, …), …)\n args: \n 1:\tSource @099 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @071 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"true","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.7 │ 0.0717 │ [0.7, 0.536, 0. ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"This results in an interval 70±7.2% which can be viewed as a reasonable improvement over 62.1±9.13%. The uncertainty in the intervals can be explained by the fact that the dataset is small with many classes.","category":"page"},{"location":"contributing/#Directory-Structure","page":"Contributing","title":"Directory Structure","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"The folder structure is as follows:","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":".\n├── Imbalance.jl # entry point to package\n├── generic_resample.jl # functions used in all resampling methods\n├── generic_encoder.jl # used in all resampling methods that deal with categorical data\n├── table_wrappers.jl # generalizes a function that operates on matrices to tables\n├── class_counts.jl # used to compute number of data points to add or remove\n├── common # has julia files for common docs, error strings and utils\n├── distance_metrics # has distance metrics used by some resampling methods\n├── oversampling_methods # all oversampling methods live here\n├── undersampling_methods # all undersampling methods live here\n└── extras.jl # extra functions like generating data or checking balance","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"The purpose of each file is further documented therein at the beginning of the file. 
The files are listed here in the recommended reading order. ","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Any resampling method implemented in the oversampling_methods or undersampling_methods folder takes the following structure:","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"├── resample_method # contains implementation and interfaces for a resampling method\n│ ├── interface_mlj.jl # implements MLJ interface for the method\n│ ├── interface_tables.jl # implements Tables.jl interface for the method\n│ └── resample_method.jl # implements the method itself (pure functional interface)","category":"page"},{"location":"contributing/#Contribution","page":"Contributing","title":"Contribution","text":"","category":"section"},{"location":"contributing/#Reporting-Problems-or-Seeking-Support","page":"Contributing","title":"Reporting Problems or Seeking Support","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Do not hesitate to post a Github issue with your question or problem.","category":"page"},{"location":"contributing/#Adding-New-Resampling-Methods","page":"Contributing","title":"Adding New Resampling Methods","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Make a new folder resample_method for the method in the oversampling_methods or undersampling_methods\nImplement in resample_method/resample_method.jl the method over matrices for one minority class\nUse generic_resample.jl to generalize it to work on the whole data\nUse table_wrappers.jl to generalize the method to work on tables and possibly use generic_encoder.jl\nImplement the MLJ interface for the method in resample_method/interface_mlj.jl\nImplement the TableTransforms interface for the method in resample_method/interface_tables.jl\nUse the rest of the files according to their description\nTesting and documentation should be done in parallel","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Of course, you can ignore the third step if the algorithm you are implementing does not operate in a \"per-class\" sense.
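 To make the second step concrete, here is a schematic of the pure functional core a new method starts from; all names are hypothetical, and the exact conventions (such as matrix orientation) should be taken from an existing method rather than from this sketch:\n\nusing Random\n\n# schematic: resample the observations of a single class\nfunction mymethod_per_class(X::AbstractMatrix{<:Real}, n::Integer; rng = Random.default_rng())\n    # trivial example core: pick n existing observations with replacement,\n    # which is essentially what random oversampling does per class\n    return X[:, rand(rng, 1:size(X, 2), n)]\nend\n\nThe generic machinery from the later steps then maps such a function over the classes according to ratios and generalizes it to tables.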
","category":"page"},{"location":"contributing/#Hot-algorithms-to-add","page":"Contributing","title":"🔥 Hot algorithms to add","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"K-Means SMOTE: Takes care of where exactly to generate more points using SMOTE by factoring in \"within class imbalance\". This may also be easily generalized to algorithms beyond SMOTE.\nCondensedNearestNeighbors: Undersamples the dataset so as to preserve the decision boundary learned by KNN\nBorderlineSMOTE2: A small modification of the BorderlineSMOTE1 condition\nRepeatedENNUndersampler: Simply repeats ENNUndersampler multiple times","category":"page"},{"location":"contributing/#Adding-New-Tutorials","page":"Contributing","title":"Adding New Tutorials","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Make a new notebook with the tutorial in the examples folder found in docs/src/examples\nRun the notebook so that the output is shown below each cell\nIf the notebook produces visuals, then save and load them in the notebook\nConvert it to markdown by using Python to run from convert import convert_to_md; convert_to_md('')\nSet a title, description, image and links for it in the dictionary found in docs/examples.jl\nFor the colab link, you do not need to upload anything; just follow the link pattern in the file","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Effect-of-ENN-Hyperparameters","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \n \"ScientificTypes\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing ScientificTypes\nusing Plots, Measures\nusing HTTP: download","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Loading-Data","page":"Effect of ENN Hyperparameters","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"In this example, we will consider the BMI dataset found on Kaggle where the objective is to predict the BMI index of individuals given their gender, weight and height. 
","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/effect_of_k_enn/bmi.csv\", \"./\")\n\ndf = CSV.read(\"./bmi.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌─────────┬────────┬────────┬───────┐\n│ Gender │ Height │ Weight │ Index │\n│ String7 │ Int64 │ Int64 │ Int64 │\n│ Textual │ Count │ Count │ Count │\n├─────────┼────────┼────────┼───────┤\n│ Male │ 174 │ 96 │ 4 │\n│ Male │ 189 │ 87 │ 2 │\n│ Female │ 185 │ 110 │ 4 │\n│ Female │ 195 │ 104 │ 3 │\n│ Male │ 149 │ 61 │ 3 │\n└─────────┴────────┴────────┴───────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will drop the gender attribute for purposes of visualization and to have more options for the model.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"select!(df, Not(:Gender)) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Coercing-Data","page":"Effect of ENN Hyperparameters","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────┬──────────┬───────┐\n│ names │ scitypes │ types │\n├────────┼──────────┼───────┤\n│ Height │ Count │ Int64 │\n│ Weight │ Count │ Int64 │\n│ Index │ Count │ Int64 │\n└────────┴──────────┴───────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Weight and Height should be Continuous and Index should be an OrderedFactor","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"df = coerce(df,\n :Height => Continuous,\n :Weight => Continuous,\n :Index => OrderedFactor)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────┬──────────────────┬─────────────────────────────────┐\n│ names │ scitypes │ types │\n├────────┼──────────────────┼─────────────────────────────────┤\n│ Height │ Continuous │ Float64 │\n│ Weight │ Continuous │ Float64 │\n│ Index │ OrderedFactor{6} │ CategoricalValue{Int64, UInt32} │\n└────────┴──────────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Unpacking-Data","page":"Effect of ENN Hyperparameters","title":"Unpacking 
Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"y, X = unpack(df, ==(:Index); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────────┬────────────┐\n│ Height │ Weight │\n│ Float64 │ Float64 │\n│ Continuous │ Continuous │\n├────────────┼────────────┤\n│ 173.0 │ 82.0 │\n│ 187.0 │ 121.0 │\n│ 144.0 │ 145.0 │\n│ 156.0 │ 74.0 │\n│ 167.0 │ 151.0 │\n└────────────┴────────────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will skip splitting the data since the main purpose of this tutorial is visualization.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Undersampling","page":"Effect of ENN Hyperparameters","title":"Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Before undersampling, let's check the balance of the data","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"checkbalance(y; ref=\"minority\")","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"0: ▇▇▇ 13 (100.0%) \n1: ▇▇▇▇▇▇ 22 (169.2%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 68 (523.1%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 69 (530.8%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 130 (1000.0%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 198 (1523.1%)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Let's use ENN undersampling to undersample the data. ENN undersamples the data by \"cleaning it out\" or in another words deleting any point that violates a certain condition. We can limit the number of points that are deleted by setting the min_ratios parameter. ","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will set k=1 and keep_condition=\"only mode\" which means that any point with a label that is not the only most common one amongst its 1-nearest neighbors will be deleted (i.e., must have same label as its nearest neighbor). By setting min_ratios=1.0 we constraint that points should never be deleted form any class if it's ratio relative to the minority class will be less than 1.0. 
","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"X_under, y_under = enn_undersample(\n\tX,\n\ty;\n\tk = 1,\n\tkeep_condition = \"only mode\",\n\tmin_ratios=0.01,\n\trng = 42,\n)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(448×2 DataFrame\n Row │ Height Weight \n │ Float64 Float64 \n─────┼──────────────────\n 1 │ 173.0 82.0\n 2 │ 182.0 70.0\n 3 │ 156.0 52.0\n 4 │ 172.0 67.0\n 5 │ 162.0 58.0\n 6 │ 180.0 75.0\n 7 │ 190.0 83.0\n 8 │ 195.0 81.0\n ⋮ │ ⋮ ⋮\n 442 │ 196.0 50.0\n 443 │ 191.0 54.0\n 444 │ 185.0 52.0\n 445 │ 182.0 50.0\n 446 │ 198.0 50.0\n 447 │ 198.0 50.0\n 448 │ 181.0 51.0\n 433 rows omitted, CategoricalArrays.CategoricalValue{Int64, UInt32}[2, 2, 2, 2, 2, 2, 2, 2, 2, 2 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"checkbalance(y_under; ref=\"minority\")","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"0: ▇▇▇ 11 (100.0%) \n1: ▇▇▇▇▇ 19 (172.7%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 56 (509.1%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 58 (527.3%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 115 (1045.5%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 189 (1718.2%)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"This indeed aligns with the cleaning behavior we described earlier.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Training-the-Model","page":"Effect of ENN Hyperparameters","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Because we have the scientific types set up, we can easily check which models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"models(matching(X_under, y_under))","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... 
)\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Let's go for an SVM from LIBSVM","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import Pkg; Pkg.add(\"LIBSVM\")\nimport LIBSVM;","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":" Updating registry at `~/.julia/registries/General`\n Updating git-repo `https://github.com/JuliaRegistries/General.git`\n Resolving package versions...\n No Changes to `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n No Changes to `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Before-Undersampling","page":"Effect of ENN Hyperparameters","title":"Before Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"# 1. Load the model\nSVC = @load SVC pkg=LIBSVM\n\n# 2. Instantiate it\nmodel = SVC(kernel=LIBSVM.Kernel.RadialBasis, gamma=0.01) ## instance\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import MLJLIBSVMInterface ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @987 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @104 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#After-Undersampling","page":"Effect of ENN Hyperparameters","title":"After Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"# 3. Wrap it with the data in a machine\nmach_under = machine(model, X_under, y_under)\n\n# 4. fit the machine learning model\nfit!(mach_under)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @123 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @423 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Plot-Decision-Boundaries","page":"Effect of ENN Hyperparameters","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Construct a range for each feature, then build a grid of points from the two ranges","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"height_range =\n\trange(minimum(X.Height) - 1, maximum(X.Height) + 1, length = 400)\nweight_range =\n\trange(minimum(X.Weight) - 1, maximum(X.Weight) + 1, length = 400)\ngrid_points = [(h, w) for h in height_range, w in weight_range]","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"400×400 Matrix{Tuple{Float64, Float64}}:\n (139.0, 49.0) (139.0, 49.2807) (139.0, 49.5614) … (139.0, 161.0)\n (139.153, 49.0) (139.153, 49.2807) (139.153, 49.5614) (139.153, 161.0)\n (139.306, 49.0) (139.306, 49.2807) (139.306, 49.5614) (139.306, 161.0)\n (139.459, 49.0) (139.459, 49.2807) (139.459, 49.5614) (139.459, 161.0)\n (139.612, 49.0) (139.612, 49.2807) (139.612, 49.5614) (139.612, 161.0)\n (139.764, 49.0) (139.764, 49.2807) (139.764, 49.5614) … (139.764, 161.0)\n (139.917, 49.0) (139.917, 49.2807) (139.917, 49.5614) (139.917, 161.0)\n (140.07, 49.0) (140.07, 49.2807) (140.07, 49.5614) (140.07, 161.0)\n (140.223, 49.0) (140.223, 49.2807) (140.223, 49.5614) (140.223, 161.0)\n (140.376, 49.0) (140.376, 49.2807) (140.376, 49.5614) (140.376, 161.0)\n ⋮ ⋱ \n (198.777, 49.0) (198.777, 49.2807) (198.777, 49.5614) (198.777, 161.0)\n (198.93, 49.0) (198.93, 49.2807) (198.93, 49.5614) (198.93, 161.0)\n (199.083, 49.0) (199.083, 49.2807) (199.083, 49.5614) (199.083, 161.0)\n (199.236, 49.0) (199.236, 49.2807) (199.236, 49.5614) (199.236, 161.0)\n (199.388, 49.0) (199.388, 49.2807) (199.388, 49.5614) … (199.388, 161.0)\n (199.541, 
49.0) (199.541, 49.2807) (199.541, 49.5614) (199.541, 161.0)\n (199.694, 49.0) (199.694, 49.2807) (199.694, 49.5614) (199.694, 161.0)\n (199.847, 49.0) (199.847, 49.2807) (199.847, 49.5614) (199.847, 161.0)\n (200.0, 49.0) (200.0, 49.2807) (200.0, 49.5614) (200.0, 161.0)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Evaluate the grid with the machine before and after undersampling","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"grid_predictions = [\n predict(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n \tpoint in grid_points\n ]\n \ngrid_predictions_under = [\n predict(mach_under, Tables.table(reshape(collect(point), 1, 2)))[1] for\n point in grid_points\n]","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Make two contour plots using the grid predictions before and after undersampling","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"colors = [:green, :aqua, :violet, :red, :blue, :yellow]\np = contourf(weight_range, height_range, grid_predictions,\n\tlevels = 6, color = colors, colorbar = false)\np_under = contourf(weight_range, height_range, grid_predictions_under,\n\tlevels = 6, color = colors, colorbar = false)\nprintln()","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"labels = unique(y)\ncolors = Dict(\n\t0 => \"green\",\n\t1 => \"cyan3\",\n\t2 => \"violet\",\n\t3 => \"red\",\n\t4 => \"dodgerblue\",\n\t5 => \"gold2\",\n)\n\nfor label in labels\n\tscatter!(p, X.Weight[y.==label], X.Height[y.==label],\n\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\ttitle = \"Before Undersampling\")\n\tscatter!(p_under, X_under.Weight[y_under.==label], X_under.Height[y_under.==label],\n\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\ttitle = \"After Undersampling\")\nend\n\nplot_res = plot(\n\tp,\n\tp_under,\n\tlayout = (1, 2),\n\txlabel = \"Weight\",\n\tylabel = \"Height\",\n\tsize = (1200, 450),\n\tmargin = 5mm, dpi = 200,\n\tlegend = :outerbottomright,\n)\nsavefig(plot_res, \"./assets/ENN-before-after.png\")\n","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(Image: enn comparison)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Effect-of-k-Hyperparameter","page":"Effect of ENN Hyperparameters","title":"Effect of k Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Now let's study the cleaning effect as k increases for each of the four keep conditions.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"anim = @animate for k ∈ 1:15\n\tconditions = [\"exists\", \"mode\", \"only mode\", \"all\"]\n\tplots = [plot() for _ in 1:4]\n\tdata_list = []\n\n\tfor i in 1:4\n\n\t\tX_under, y_under = 
enn_undersample(\n\t\t\tX,\n\t\t\ty;\n\t\t\tk = k,\n\t\t\tkeep_condition = conditions[i],\n\t\t\tmin_ratios = 0.01,\n\t\t\trng = 42,\n\t\t)\n\n\t\t# fit machine\n\t\tmach_under = machine(model, X_under, y_under)\n\t\tfit!(mach_under, verbosity = 0)\n\n\t\t# grid predictions\n\t\tgrid_predictions_under = [\n\t\t\tpredict(mach_under, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\t\tpoint in grid_points\n\t\t]\n\n\t\t# plot\n\t\tcolors = [:green, :aqua, :violet, :red, :blue, :yellow]\n\t\tcontourf!(plots[i], weight_range, height_range, grid_predictions_under,\n\t\t\tlevels = 6, color = colors, colorbar = false)\n\n\t\tcolors = Dict(\n\t\t\t0 => \"green\",\n\t\t\t1 => \"cyan3\",\n\t\t\t2 => \"violet\",\n\t\t\t3 => \"red\",\n\t\t\t4 => \"dodgerblue\",\n\t\t\t5 => \"gold2\",\n\t\t)\n\t\tfor label in labels\n\t\t\tscatter!(plots[i], X_under.Weight[y_under.==label],\n\t\t\t\tX_under.Height[y_under.==label],\n\t\t\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\t\t\ttitle = \"$(conditions[i])\", legend = ((i == 2) ? :bottomright : :none))\n\t\tend\n\t\tplot!(\n\t\t\tplots[1], plots[2], plots[3], plots[4],\n\t\t\tlayout = (1, 4),\n\t\t\tsize = (1300, 420),\n\t\t\tplot_title = \"Undersampling with k =$k\",\n\t\t)\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"gif(anim, \"./assets/enn-k-animation.gif\", fps=1)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(Image: enn-gif-hyperparameter)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"As we can see, the most constraining condition is all. 
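It deletes any point whose label differs from that of any of its k nearest neighbors, which also explains why it is the most sensitive to the hyperparameter k.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"To make the four keep conditions concrete, the following is a minimal sketch of how each condition could decide whether a point is kept, given its own label l and the labels of its k nearest neighbors. It is an illustration under our own simplifying assumptions, not Imbalance.jl's actual implementation, and it assumes StatsBase is available for modes.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"using StatsBase: modes\n\n# Hedged sketch of the four keep conditions (illustration only):\n# l is the point's own label; neighbor_labels are the labels of its k nearest neighbors\nfunction keep_point(l, neighbor_labels, keep_condition)\n\tkeep_condition == \"exists\" && return l in neighbor_labels\n\tkeep_condition == \"mode\" && return l in modes(neighbor_labels)\n\tkeep_condition == \"only mode\" && return modes(neighbor_labels) == [l]\n\tkeep_condition == \"all\" && return all(==(l), neighbor_labels)\nend","category":"page"}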
,{"location":"examples/smoten_mushroom/smoten_mushroom/#SMOTEN-on-Mushroom-Data","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing StatsBase\nusing ScientificTypes\nusing Plots\nusing HTTP: download","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Loading-Data","page":"SMOTEN on Mushroom Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"In this example, we will consider the Mushroom dataset found on Kaggle, with the objective of predicting mushroom odour given various features about the mushroom.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"CSV lets us easily read the dataset once it's downloaded, as follows","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smoten_mushroom/mushrooms.csv\", \"./\")\ndf = CSV.read(\"./mushrooms.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌─────────┬───────────┬─────────────┬───────────┬─────────┬─────────┬─────────────────┬──────────────┬───────────┬────────────┬─────────────┬────────────┬──────────────────────────┬──────────────────────────┬────────────────────────┬────────────────────────┬───────────┬────────────┬─────────────┬───────────┬───────────────────┬────────────┬─────────┐\n│ class │ cap-shape │ cap-surface │ cap-color │ bruises │ odor │ gill-attachment │ gill-spacing │ gill-size │ gill-color │ stalk-shape │ stalk-root │ stalk-surface-above-ring │ stalk-surface-below-ring │ stalk-color-above-ring │ stalk-color-below-ring │ veil-type │ veil-color │ ring-number │ ring-type │ spore-print-color │ population │ habitat │\n│ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │\n│ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual 
│\n├─────────┼───────────┼─────────────┼───────────┼─────────┼─────────┼─────────────────┼──────────────┼───────────┼────────────┼─────────────┼────────────┼──────────────────────────┼──────────────────────────┼────────────────────────┼────────────────────────┼───────────┼────────────┼─────────────┼───────────┼───────────────────┼────────────┼─────────┤\n│ p │ x │ s │ n │ t │ p │ f │ c │ n │ k │ e │ e │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ u │\n│ e │ x │ s │ y │ t │ a │ f │ c │ b │ k │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ n │ n │ g │\n│ e │ b │ s │ w │ t │ l │ f │ c │ b │ n │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ n │ n │ m │\n│ p │ x │ y │ w │ t │ p │ f │ c │ n │ n │ e │ e │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ u │\n│ e │ x │ s │ g │ f │ n │ f │ w │ b │ k │ t │ e │ s │ s │ w │ w │ p │ w │ o │ e │ n │ a │ g │\n└─────────┴───────────┴─────────────┴───────────┴─────────┴─────────┴─────────────────┴──────────────┴───────────┴────────────┴─────────────┴────────────┴──────────────────────────┴──────────────────────────┴────────────────────────┴────────────────────────┴───────────┴────────────┴─────────────┴───────────┴───────────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Visualize-the-Data","page":"SMOTEN on Mushroom Data","title":"Visualize the Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Since this dataset is composed only of categorical features, a bar chart for each column is a good way to visualize the data.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# Create a bar chart for each column\nbar_charts = []\nfor col in names(df)\n counts = countmap(df[!, col])\n k, v = collect(keys(counts)), collect(values(counts))\n if length(k) < 20\n push!(bar_charts, bar(k, v, legend=false, title=col))\n end\nend\n\n# Combine bar charts into a grid layout with specified plot size\nplot_res = plot(bar_charts..., layout=(5, 5), \n size=(1300, 1200), \n plot_title=\"Value Frequencies for each Categorical Variable\")\nsavefig(plot_res, \"./assets/mushroom-bar-charts.png\")","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"(Image: Mushroom Features Plots)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We will take the mushroom odour as our target and all the rest as features. ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Coercing-Data","page":"SMOTEN on Mushroom Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. 
It's often necessary to coerce the types inferred by default to the appropriate type.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌──────────────────────────┬──────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────────────────┼──────────┼─────────┤\n│ class │ Textual │ String1 │\n│ cap-shape │ Textual │ String1 │\n│ cap-surface │ Textual │ String1 │\n│ cap-color │ Textual │ String1 │\n│ bruises │ Textual │ String1 │\n│ odor │ Textual │ String1 │\n│ gill-attachment │ Textual │ String1 │\n│ gill-spacing │ Textual │ String1 │\n│ gill-size │ Textual │ String1 │\n│ gill-color │ Textual │ String1 │\n│ stalk-shape │ Textual │ String1 │\n│ stalk-root │ Textual │ String1 │\n│ stalk-surface-above-ring │ Textual │ String1 │\n│ stalk-surface-below-ring │ Textual │ String1 │\n│ stalk-color-above-ring │ Textual │ String1 │\n│ stalk-color-below-ring │ Textual │ String1 │\n│ ⋮ │ ⋮ │ ⋮ │\n└──────────────────────────┴──────────┴─────────┘\n 7 rows omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"For instance, here we need to coerce all the data to Multiclass as they are all nominal variables. Textual would be the right type for natural language processing models. Instead of typing in each column manually, autotype lets us perform mass conversion using pre-defined rules.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"df = coerce(df, autotype(df, :few_to_finite))\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌──────────────────────────┬────────────────┬───────────────────────────────────\n│ names │ scitypes │ types ⋯\n├──────────────────────────┼────────────────┼───────────────────────────────────\n│ class │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-shape │ Multiclass{6} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-surface │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-color │ Multiclass{10} │ CategoricalValue{String1, UInt32 ⋯\n│ bruises │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ odor │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-attachment │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-spacing │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-size │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-color │ Multiclass{12} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-shape │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-root │ Multiclass{5} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-surface-above-ring │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-surface-below-ring │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-color-above-ring │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-color-below-ring │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ ⋮ │ ⋮ │ ⋮ ⋱\n└──────────────────────────┴────────────────┴───────────────────────────────────\n 1 column and 7 rows 
omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Unpacking-and-Splitting-Data","page":"SMOTEN on Mushroom Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y, X = unpack(df, ==(:odor); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┐\n│ class │ cap-shape │ cap-surface │ cap-color │ bruises │ gill-attachment │ gill-spacing │ gill-size │ gill-color │ stalk-shape │ stalk-root │ stalk-surface-above-ring │ stalk-surface-below-ring │ stalk-color-above-ring │ stalk-color-below-ring │ veil-type │ veil-color │ ring-number │ ring-type │ spore-print-color │ population │ habitat │\n│ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │\n│ Multiclass{2} │ Multiclass{6} │ Multiclass{4} │ Multiclass{10} │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{12} │ Multiclass{2} │ Multiclass{5} │ Multiclass{4} │ Multiclass{4} │ Multiclass{9} │ Multiclass{9} │ Multiclass{1} │ Multiclass{4} │ Multiclass{3} │ Multiclass{5} │ Multiclass{9} │ Multiclass{6} │ Multiclass{7} 
│\n├───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┤\n│ e │ f │ f │ n │ t │ f │ c │ b │ w │ t │ b │ s │ s │ g │ g │ p │ w │ o │ p │ k │ v │ d │\n│ e │ f │ f │ n │ t │ f │ c │ b │ w │ t │ b │ s │ s │ w │ p │ p │ w │ o │ p │ n │ y │ d │\n│ e │ b │ s │ y │ t │ f │ c │ b │ k │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ g │\n│ p │ f │ y │ e │ f │ f │ c │ b │ w │ e │ c │ k │ y │ c │ c │ p │ w │ n │ n │ w │ c │ d │\n│ e │ x │ y │ n │ f │ f │ w │ n │ w │ e │ b │ f │ f │ w │ n │ p │ w │ o │ e │ w │ v │ l │\n└───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. 
stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits which is more representative of the real world.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, stratify=y, rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"(CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\"), String1(\"f\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\"), String1(\"s\")], CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"f\"), String1(\"y\"), String1(\"a\"), String1(\"c\"), String1(\"f\"), String1(\"n\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"f\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"f\"), String1(\"y\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\")])","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations then train-test contamination has occurred; novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Oversampling","page":"SMOTEN on Mushroom Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"It was obvious from the bar charts that there is a severe imbalance problem. Let's look at that again.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"m: ▇ 36 (1.0%) \nc: ▇▇▇ 192 (5.4%) \np: ▇▇▇▇ 256 (7.3%) \na: ▇▇▇▇▇▇ 400 (11.3%) \nl: ▇▇▇▇▇▇ 400 (11.3%) \ny: ▇▇▇▇▇▇▇▇ 576 (16.3%) \ns: ▇▇▇▇▇▇▇▇ 576 (16.3%) \nf: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2160 (61.2%) \nn: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 3528 (100.0%)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Let's set our desired ratios as follows. 
These are set relative to the size of the majority class.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ratios = Dict(\"m\"=>0.3, \n \"c\"=>0.4,\n \"p\"=>0.5,\n \"a\"=>0.5,\n \"l\"=>0.5,\n \"y\"=>0.7,\n \"s\"=>0.7,\n \"f\"=>0.8\n )","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Dict{String, Float64} with 8 entries:\n \"s\" => 0.7\n \"f\" => 0.8\n \"c\" => 0.4\n \"m\" => 0.3\n \"l\" => 0.5\n \"a\" => 0.5\n \"p\" => 0.5\n \"y\" => 0.7","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We have set these by intuition here, but in practice the ratios are among the most important hyperparameters to tune. ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"The easy option ratios=1.0 always exists and would mean that we want to oversample data in each class so that they all match the majority class. It may or may not be optimal, however, due to overfitting concerns.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Xover, yover = smoten(X_train, y_train; k=2, ratios=ratios, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"(15239×22 DataFrame\n Row │ class cap-shape cap-surface cap-color bruises gill-attachment g ⋯\n │ Cat… Cat… Cat… Cat… Cat… Cat… C ⋯\n───────┼────────────────────────────────────────────────────────────────────────\n 1 │ p f s e f f c ⋯\n 2 │ p f y e f f c\n 3 │ e f f w f f w\n 4 │ p f s e f f c\n 5 │ p f y e f f c ⋯\n 6 │ e s f g f f c\n 7 │ p f s n f f c\n 8 │ e x y g t f c\n ⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱\n 15233 │ p x y c f a c ⋯\n 15234 │ p x y e f a c\n 15235 │ p x y n f a c\n 15236 │ p k y c f f c\n 15237 │ p x y c f a c ⋯\n 15238 │ p k y c f f c\n 15239 │ p x y e f f c\n 16 columns and 15224 rows omitted, CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\")])","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"SMOTEN uses a specialized distance metric to find the nearest neighbors, which explains why it may be a bit slow: it is nontrivial to optimize KNN over such a metric.","category":"page"}
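,{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"The metric is in the spirit of the Value Difference Metric used in the SMOTEN paper: two values of a categorical feature are close when they induce similar class distributions. The rough sketch below conveys the idea only; it assumes a precomputed table P where P[c][v] is the fraction of points with feature value v that belong to class c, and it is not Imbalance.jl's actual implementation.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# Sketch of a per-feature distance between two category values a and b:\n# values are close when P(class | value) looks similar for every class\nvdm(a, b, P) = sum(abs(P[c][a] - P[c][b]) for c in keys(P))","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Now let's check the balance of the data","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom 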
Data","text":"checkbalance(yover)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"m: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 847 (30.0%) \nc: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1129 (40.0%) \na: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \nl: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \np: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \ny: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1975 (70.0%) \ns: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1975 (70.0%) \nf: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2258 (80.0%) \nn: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2822 (100.0%)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Training-the-Model","page":"SMOTEN on Mushroom Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Because we have scientific types setup, we can easily check what models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ms = models(matching(Xover, yover))","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"6-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = OneRuleClassifier, package_name = OneRule, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Let's go for a OneRuleClassifier","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import Pkg; Pkg.add(\"OneRule\")","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":" Resolving package versions...\n Installed MLJBalancing ─ v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n [45f359ea] + MLJBalancing v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`\n [45f359ea] + MLJBalancing v0.1.0\nPrecompiling project...\n ✓ MLJBalancing\n 1 dependency successfully precompiled in 25 seconds. 
262 already precompiled.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 1. Load the model\nOneRuleClassifier= @load OneRuleClassifier pkg=OneRule\n\n# 2. Instantiate it\nmodel = OneRuleClassifier()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import OneRule ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. \n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\ntrained Machine; caches model-specific representations of data\n model: OneRuleClassifier()\n args: \n 1:\tSource @978 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @097 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"trained Machine; caches model-specific representations of data\n model: OneRuleClassifier()\n args: \n 1:\tSource @469 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @942 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Evaluating-the-Model","page":"SMOTEN on Mushroom Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"To evaluate the model, we will use the balanced accuracy metric, which accounts equally for all classes (it is the average of the per-class recalls).","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling-2","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y_pred = MLJ.predict(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"0.22","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling-2","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y_pred_over = MLJ.predict(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"0.4","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Evaluating-the-Model-Revisited","page":"SMOTEN on Mushroom Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy, resulting in a full-blown 18% improvement. A more precise evaluation would use cross-validation to combine many different point estimates into a more precise one (their average). 
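The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation implies a smaller confidence interval at the same probability.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Concretely, the 1.96*SE column reported by evaluate! below is just 1.96 standard errors of the per-fold scores. Here is a minimal sketch with hypothetical fold scores, assuming the usual standard-error formula:","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"using Statistics\n\n# hypothetical per-fold balanced accuracies, for illustration only\nfold_scores = [0.218, 0.217, 0.219, 0.218, 0.218]\npoint_estimate = mean(fold_scores)\ninterval = 1.96 * std(fold_scores) / sqrt(length(fold_scores)) # 1.96*SE\npoint_estimate, interval","category":"page"}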
,{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling-3","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬───────────┬─────────────┬──────────┬────────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼───────────┼─────────────┼──────────┼────────────────────\n│ BalancedAccuracy( │ predict │ 0.218 │ 0.000718 │ [0.218, 0.218, 0. ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴───────────┴─────────────┴──────────┴────────────────────\n 1 column omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Before oversampling, and assuming that the balanced accuracy score is normally distributed, we can be 95% confident that the balanced accuracy on new data is 21.8±0.07%. This is a more precise estimate than the single 22% figure we had earlier.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling-3","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"At first glance, this seems really nontrivial, since resampling would have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"For this, we must construct the resampler via its MLJ interface and then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTEN(k=2, ratios=ratios, rng=Random.Xoshiro(42))\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"trained Machine; does not cache data\n model: BalancedModelDeterministic(balancers = Imbalance.MLJ.SMOTEN{Dict{String, Float64}, Xoshiro}[SMOTEN(k = 2, …)], …)\n args: \n 1:\tSource @692 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @468 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":" y_pred_over == predict(mach_over, X_test)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"true","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"cv=CV(nfolds=10)\ne = evaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"e","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬───────────┬─────────────┬─────────┬─────────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼───────────┼─────────────┼─────────┼─────────────────────\n│ BalancedAccuracy( │ predict │ 0.4 │ 0.00483 │ [0.398, 0.405, 0.3 ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴───────────┴─────────────┴─────────┴─────────────────────\n 1 column omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Fair enough. 
After oversampling, the interval under the same assumptions is 40±0.5%; this agrees with our earlier observations using simple point estimates: oversampling here delivers approximately an 18% improvement in balanced accuracy.","category":"page"},{"location":"algorithms/implementation_notes/#Generalizing-to-Multiclass","page":"Implementation Notes","title":"Generalizing to Multiclass","text":"","category":"section"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"Papers often propose the resampling algorithm for the case of binary classification only. In many cases, the algorithm only expects a set of points to resample and has nothing to do with the existence of a majority class (e.g., estimates the distribution of points then generates new samples from it) so it can be generalized by simply applying the algorithm on each class. In other cases, there is an interaction with the majority class (e.g., a point is borderline in BorderlineSMOTE1 if the majority but not all its neighbors are from the majority class). In this case, a one-vs-rest scheme is used as proposed in [1]. For instance, a point is now borderline if most but not all its neighbors are from a different class. ","category":"page"},{"location":"algorithms/implementation_notes/#Generalizing-to-Real-Ratios","page":"Implementation Notes","title":"Generalizing to Real Ratios","text":"","category":"section"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"Papers often propose the resampling algorithm using integer ratios. For instance, a ratio of 2 would mean to double the amount of data in a class, and a ratio of 2.2 is not allowed or will be rounded. In Imbalance.jl any appropriate real ratio can be used, and the ratio is relative to the size of the majority or minority class depending on whether the algorithm is oversampling or undersampling. The generalization occurs by randomly choosing points instead of looping on each point. That is, if a 2.2 ratio corresponds to 227 examples then 227 examples are chosen randomly with replacement and the resampling logic is applied to each. Given an integer ratio k, this is on average equivalent to looping on the points k times.","category":"page"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"[1] Fernández, A., López, V., Galar, M., Del Jesus, M. J., and Herrera, F. (2013). Analysing the classification of imbalanced data-sets with multiple classes: Binarization techniques and ad-hoc approaches. 
Knowledge-Based Systems, 42:97–110.","category":"page"},{"location":"algorithms/mlj_balancing/#Combining-Resamplers","page":"Combination","title":"Combining Resamplers","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Resampling methods can be combined sequentially or in parallel, along with a classification model, to yield hybrid or ensemble models that may be even more powerful than using the classification model with only one of the individual resamplers.","category":"page"},{"location":"algorithms/mlj_balancing/#Sequential-Resampling","page":"Combination","title":"Sequential Resampling","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"MLJBalancing.jl allows chaining an arbitrary number of resamplers from Imbalance.jl (also called balancers) with classification models from MLJ via BalancedModel. This makes it possible to use BalancedModel to form hybrid resampling methods that combine oversampling and under-sampling methods in a linear pipeline, such as SMOTE-Tomek and SMOTE-ENN.","category":"page"},{"location":"algorithms/mlj_balancing/#Construct-the-resampler-and-classification-models","page":"Combination","title":"Construct the resampler and classification models","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"SMOTE = @load SMOTE pkg=Imbalance verbosity=0\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n\noversampler = SMOTE(k=5, ratios=1.0, rng=42)\nundersampler = TomekUndersampler(min_ratios=0.5, rng=42)\n\nlogistic_model = LogisticClassifier()","category":"page"},{"location":"algorithms/mlj_balancing/#Wrap-them-all-in-BalancedModel","page":"Combination","title":"Wrap them all in BalancedModel","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"balanced_model = BalancedModel(model=logistic_model, \n balancer1=oversampler, balancer2=undersampler)","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Here training data will be passed to balancer1 then balancer2, whose output is used to train the classifier model. In prediction, the resamplers balancer1 and balancer2 are bypassed; the wrapped combination behaves like one single MLJ model that can be fit, validated or fine-tuned like any other.","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"In general, there can be any number of balancers, and the user can give the balancers arbitrary names. ","category":"page"},{"location":"algorithms/mlj_balancing/#Parallel-Resampling-with-Balanced-Bagging","page":"Combination","title":"Parallel Resampling with Balanced Bagging","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"MLJBalancing.jl also offers an implementation of bagging over probabilistic classifiers, where the majority class is randomly undersampled T times down to the size of the minority class, then a model is trained on each of the T undersampled datasets. The predictions are then aggregated by averaging. 
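This is offered via BalancedBaggingClassifier and can only be used for binary classification.","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"The following sketch conveys the underlying idea only; it is not the MLJBalancing implementation. It assumes a probabilistic classifier model, data X and y, a new table Xnew, and a class label positive_class, all of which are stand-ins:","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"using Imbalance, MLJ, Statistics\n\nT = 10\nmachs = map(1:T) do t\n\t# equalize the classes by random undersampling, with a different seed per bag\n\tXb, yb = random_undersample(X, y; ratios = 1.0, rng = t)\n\tmach_t = machine(model, Xb, yb)\n\tfit!(mach_t, verbosity = 0)\n\tmach_t\nend\n\n# average the T probabilistic predictions, e.g. for the positive class\navg_probs = mean(pdf.(predict(mach_t, Xnew), positive_class) for mach_t in machs)","category":"page"}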
,{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"BalancedBaggingClassifier(model=nothing, T=0, rng = Random.default_rng(),)","category":"page"},{"location":"algorithms/mlj_balancing/#Arguments","page":"Combination","title":"Arguments","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"model::Probabilistic: A probabilistic classification model that implements the MLJModelInterface\nT::Integer=0: The number of bags to be used in the ensemble. If not given, will be set as the ratio between the frequency of the majority and minority classes.\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer ","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"seed to be used with Xoshiro","category":"page"},{"location":"algorithms/mlj_balancing/#Example","page":"Combination","title":"Example","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"using MLJ\nusing Imbalance\nusing MLJBalancing\nusing Random\n\nX, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n probs = [0.9, 0.1], \n type = \"ColTable\", \n rng=42)\n\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\nlogistic_model = LogisticClassifier()\nbagging_model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random.Xoshiro(42))","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Now you can fit, predict, cross-validate and fine-tune it like any other probabilistic MLJ model, where X must be a table input (e.g., a dataframe).","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"mach = machine(bagging_model, X, y)\nfit!(mach)\npred = predict(mach, X)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#From-RandomOversampling-to-ROSE","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"\nimport Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\",\n \"ScientificTypes\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing ScientificTypes\nusing Imbalance\nusing Plots, Measures\nusing HTTP: download","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Loading-Data","page":"From RandomOversampling to ROSE","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Let's load the Iris dataset. The objective of this dataset is to predict the type of flower as one of \"virginica\", \"versicolor\" and \"setosa\" using its sepal and petal length and width. ","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"We don't need to read it from a CSV file this time because MLJ has a macro for loading it already! 
The only difference is that we will need to explicitly convert it to a dataframe, as MLJ loads it as a named tuple of vectors.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"X, y = @load_iris\nX = DataFrame(X)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌──────────────┬─────────────┬──────────────┬─────────────┐\n│ sepal_length │ sepal_width │ petal_length │ petal_width │\n│ Float64 │ Float64 │ Float64 │ Float64 │\n│ Continuous │ Continuous │ Continuous │ Continuous │\n├──────────────┼─────────────┼──────────────┼─────────────┤\n│ 5.1 │ 3.5 │ 1.4 │ 0.2 │\n│ 4.9 │ 3.0 │ 1.4 │ 0.2 │\n│ 4.7 │ 3.2 │ 1.3 │ 0.2 │\n│ 4.6 │ 3.1 │ 1.5 │ 0.2 │\n│ 5.0 │ 3.6 │ 1.4 │ 0.2 │\n└──────────────┴─────────────┴──────────────┴─────────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Our purpose for this tutorial is primarily visualization. Thus, let's select only two of the continuous features to work with. It's known that the petal length and width play a much bigger role in classifying the type of flower, so let's keep only those.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"X = select(X, :petal_width, :petal_length)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌─────────────┬──────────────┐\n│ petal_width │ petal_length │\n│ Float64 │ Float64 │\n│ Continuous │ Continuous │\n├─────────────┼──────────────┤\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.3 │\n│ 0.2 │ 1.5 │\n│ 0.2 │ 1.4 │\n└─────────────┴──────────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Coercing-Data","page":"From RandomOversampling to ROSE","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"ScientificTypes.schema(X)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌──────────────┬────────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────┼────────────┼─────────┤\n│ petal_width │ Continuous │ Float64 │\n│ petal_length │ Continuous │ Float64 │\n└──────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Things look good; no coercion is needed.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Oversampling","page":"From RandomOversampling to ROSE","title":"Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Iris, by default, has no imbalance problem","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to 
ROSE","text":"checkbalance(y)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"virginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"To simulate that there is a balance problem, we will consider a random sample of 100 observations. A random sample does not guarantee perserving the proportion of classes; in this, we actually set the seed to get a very unlikely random sample that suffers from strong imbalance.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Random.seed!(803429)\nsubset_indices = rand(1:size(X, 1), 100)\nX, y = X[subset_indices, :], y[subset_indices]\ncheckbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"versicolor: ▇▇▇▇▇▇▇▇▇▇▇ 12 (22.6%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"We will treat this as our training set going forward so we don't need to partition. Now let's oversample it with ROSE.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Xover, yover = rose(X, y; s=0.3, ratios=Dict(\"versicolor\" => 1.0, \"setosa\"=>1.0))\ncheckbalance(yover)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Progress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[K\n\u001b[A\n\nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Training-the-Model","page":"From RandomOversampling to ROSE","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, 
:supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Let's go for a BayesianLDA.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"import Pkg; Pkg.add(\"MultivariateStats\")","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Before-Oversampling","page":"From RandomOversampling to ROSE","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"# 1. Load the model\nBayesianLDA = @load BayesianLDA pkg=MultivariateStats\n\n# 2. Instantiate it \nmodel = BayesianLDA()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#After-Oversampling","page":"From RandomOversampling to ROSE","title":"After Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Plot-Decision-Boundaries","page":"From RandomOversampling to ROSE","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Construct ranges for each feature and subsequently a grid","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"petal_width_range =\n\trange(minimum(X.petal_width) - 1, maximum(X.petal_width) + 1, length = 200)\npetal_length_range =\n\trange(minimum(X.petal_length) - 1, maximum(X.petal_length) + 1, length = 200)\ngrid_points = [(pw, pl) for pw in petal_width_range, pl in petal_length_range]\n","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"200×200 Matrix{Tuple{Float64, Float64}}:\n (-0.9, 0.2) (-0.9, 0.238693) … (-0.9, 7.9)\n (-0.878894, 0.2) (-0.878894, 0.238693) (-0.878894, 7.9)\n (-0.857789, 0.2) (-0.857789, 0.238693) (-0.857789, 7.9)\n (-0.836683, 0.2) (-0.836683, 0.238693) (-0.836683, 7.9)\n (-0.815578, 0.2) (-0.815578, 0.238693) (-0.815578, 7.9)\n (-0.794472, 0.2) (-0.794472, 0.238693) … (-0.794472, 7.9)\n (-0.773367, 0.2) (-0.773367, 0.238693) (-0.773367, 7.9)\n (-0.752261, 0.2) (-0.752261, 0.238693) (-0.752261, 7.9)\n (-0.731156, 0.2) (-0.731156, 0.238693) (-0.731156, 7.9)\n (-0.71005, 0.2) (-0.71005, 0.238693) (-0.71005, 7.9)\n ⋮ ⋱ \n (3.13116, 0.2) (3.13116, 0.238693) (3.13116, 7.9)\n (3.15226, 0.2) (3.15226, 0.238693) (3.15226, 7.9)\n (3.17337, 0.2) (3.17337, 0.238693) (3.17337, 7.9)\n (3.19447, 0.2) (3.19447, 0.238693) (3.19447, 7.9)\n (3.21558, 0.2) (3.21558, 0.238693) … (3.21558, 7.9)\n (3.23668, 0.2) (3.23668, 0.238693) (3.23668, 7.9)\n (3.25779, 0.2) (3.25779, 0.238693) (3.25779, 7.9)\n (3.27889, 0.2) (3.27889, 0.238693) (3.27889, 7.9)\n (3.3, 0.2) (3.3, 0.238693) (3.3, 7.9)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Evaluate the grid with the machine before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"grid_predictions = [\n\tpredict_mode(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]\ngrid_predictions_over = [\n\tpredict_mode(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"200×200 CategoricalArrays.CategoricalArray{String,2,UInt32}:\n \"setosa\" \"setosa\" \"setosa\" … \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" … \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" 
\"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n ⋮ ⋱ \n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" … \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Make two contour plots using the grid predictions before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"p = contourf(petal_length_range, petal_width_range, grid_predictions,\n\tlevels = 3, color = :Set3_3, colorbar = false)\np_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\tlevels = 3, color = :Set3_3, colorbar = false)\nprintln()","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Scatter plot the data before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"old_count = size(X, 1)\nlabels = unique(y)\ncolors = Dict(\"setosa\" => \"green\", \"versicolor\" => \"yellow\",\n\t\"virginica\" => \"purple\")\n\nfor label in labels\n\tscatter!(p, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"Before Oversampling\")\n\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"After Oversampling\")\n\t# find new points only and plot with different shape\n\tscatter!(p_over, Xover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\tcolor = colors[label], label = label*\"-over\", markershape = :diamond,\n\t\ttitle = \"After Oversampling\")\nend\n\nplot_res = plot(\n\tp,\n\tp_over,\n\tlayout = (1, 2),\n\txlabel = \"petal length\",\n\tylabel = \"petal width\",\n\tsize = (900, 300),\n\tmargin = 5mm, dpi = 200\n)\nsavefig(plot_res, \"./assets/ROSE-before-after.png\")","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"(Image: Before After ROSE)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Effect-of-Increasing-s","page":"From RandomOversampling to ROSE","title":"Effect of Increasing s","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"anim = @animate for s ∈ 0:0.03:6.0\n\t# oversample\n\tXover, yover =\n\t\trose(X, y; s = s, ratios = Dict(\"setosa\" => 1.0, \"versicolor\" => 1.0), rng = 42)\n\n\tmodel = BayesianLDA()\n\tmach_over = machine(model, 
Xover, yover)\n\tfit!(mach_over, verbosity = 0)\n\n\t# grid predictions\n\tgrid_predictions_over = [\n\t\tpredict_mode(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\tpoint in grid_points\n\t]\n\n\tp_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\t\tlevels = 3, color = :Set3_3, colorbar = false)\n\n\told_count = size(X, 1)\n\tfor label in labels\n\t\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\t\tcolor = colors[label], label = label,\n\t\t\ttitle = \"Oversampling with s=$s\")\n\t\t# find new points only and plot with different shape\n\t\tscatter!(p_over,\n\t\t\tXover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tcolor = colors[label], label = label * \"-over\", markershape = :diamond,\n\t\t\ttitle = \"Oversampling with s=$s\")\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"gif(anim, \"./assets/rose-animation.gif\", fps=6)\nprintln()","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"(Image: ROSE Effect of S)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"As we can see, the larger s is, the more spread out the oversampled points are. This is expected because ROSE oversamples by sampling from the distribution that corresponds to placing Gaussians on the existing points, and s is a hyperparameter proportional to the bandwidth of those Gaussians. When s=0, the only points that can be generated lie on top of existing ones; i.e., ROSE becomes equivalent to random oversampling.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"The decision boundary shifts from frame to frame because the model is refit on a freshly oversampled dataset for each value of s.","category":"page"},
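{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"To make the role of s concrete, here is a simplified, univariate sketch of the sampling step just described. rose_like_sample is a hypothetical helper for illustration only; it is not the package's implementation (which is multivariate and uses a proper smoothing matrix) and it requires the Distributions package:","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"using Random, Statistics, Distributions\n\n# Center a Gaussian on a randomly chosen existing point; its spread grows\n# with s. With s = 0 the sample lands exactly on an existing point.\nfunction rose_like_sample(points::Vector{Float64}, s::Real; rng = Random.default_rng())\n    center = rand(rng, points)\n    h = s * std(points)   # bandwidth proportional to s (simplified)\n    return rand(rng, Normal(center, h))\nend","category":"page"},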
{"location":"about/#Credits","page":"About","title":"Credits","text":"","category":"section"},{"location":"about/","page":"About","title":"About","text":"This package was created by Essam Wisam as a Google Summer of Code project, under the mentorship of Anthony Blaom. Special thanks also go to Rik Huijzer for his friendliness and the binary SMOTE implementation in Resample.jl.","category":"page"},{"location":"algorithms/undersampling_algorithms/#Undersampling-Algorithms","page":"Undersampling","title":"Undersampling Algorithms","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"The following table lists the supported undersampling algorithms, whether the mechanism deletes existing data or generates new data, and the supported data types.","category":"page"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"Undersampling Method Mechanism Supported Data Types\nRandom Undersampler Delete existing data as needed Continuous and/or nominal\nCluster Undersampler Generate new data or delete existing data Continuous\nEdited Nearest Neighbors Undersampler Delete existing data meeting certain conditions (cleaning) Continuous\nTomek Links Undersampler Delete existing data meeting certain conditions (cleaning) Continuous","category":"page"},{"location":"algorithms/undersampling_algorithms/#Random-Undersampler","page":"Undersampling","title":"Random Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"random_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.random_undersample","page":"Undersampling","title":"Imbalance.random_undersample","text":"random_undersample(\n    X, y; \n    ratios=1.0, rng=default_rng(), \n    try_preserve_type=true\n)\n\nDescription\n\nNaively undersample a dataset by randomly deleting existing observations.\n\nPositional Arguments\n\nX: A matrix of real numbers or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of undersampling to be done for each class\nCan be a float and in this case each class will be undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n                                class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n 1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n 2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n 0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply random undersampling\nX_under, y_under = random_undersample(X, y; ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), \n                                      rng=42)\n \njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nRandomUndersampler = @load RandomUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = RandomUndersampler(ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), \n                                  rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is, followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate Random Undersampler model\nundersampler = RandomUndersampler(y_ind; ratios=Dict(0=>1.0, 1=>1.0, 2=>1.0), rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/undersampling_algorithms/#Cluster-Undersampler","page":"Undersampling","title":"Cluster Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"cluster_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.cluster_undersample","page":"Undersampling","title":"Imbalance.cluster_undersample","text":"cluster_undersample(\n    X, y; \n    mode= \"nearest\", ratios = 1.0, maxiter = 100,\n    rng=default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset using clustering undersampling as presented in [1] using K-means as the clustering algorithm.\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nmode::AbstractString=\"nearest\": If \"center\" then the undersampled data will consist of the centroids of each cluster found; meanwhile, if \"nearest\" then it will consist of the nearest neighbor of each centroid.\nratios=1.0: A parameter that controls the amount of undersampling to be done for each class\nCan be a float and in this case each class will be undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nmaxiter::Integer=100: Maximum number of iterations to run K-means\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n                                class_probs, rng=42) \n \njulia> Imbalance.checkbalance(y; ref=\"minority\")\n 1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n 2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n 0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply cluster_undersampling\nX_under, y_under = cluster_undersample(X, y; mode=\"nearest\", \n                                       ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), rng=42)\n \njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ClusterUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nClusterUndersampler = @load ClusterUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = ClusterUndersampler(mode=\"nearest\", \n                                   ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is, followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate ClusterUndersampler model\nundersampler = ClusterUndersampler(y_ind; mode=\"nearest\", \n                                   ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Wei-Chao, L., Chih-Fong, T., Ya-Han, H., & Jing-Shang, J. (2017). Clustering-based undersampling in class-imbalanced data. Information Sciences, 409–410, 17–26.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/undersampling_algorithms/#Edited-Nearest-Neighbors-Undersampler","page":"Undersampling","title":"Edited Nearest Neighbors Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"enn_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.enn_undersample","page":"Undersampling","title":"Imbalance.enn_undersample","text":"enn_undersample(\n    X, y; k = 5, keep_condition = \"mode\",\n    min_ratios = 1.0, force_min_ratios = false,\n    rng = default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset by removing points that violate a certain condition such as belonging to a different class compared to the majority of the neighbors, as proposed in [1].\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the data. It will be automatically set to n-1 if n ≤ k.\n\nkeep_condition::AbstractString=\"mode\": The condition that leads to removing a point upon violation. Takes one of \"exists\", \"mode\", \"only mode\" and \"all\"\n\"exists\": the point has at least one neighbor from the same class\n\"mode\": the class of the point is one of the most frequent classes of the neighbors (there may be many)\n\"only mode\": the class of the point is the single most frequent class of the neighbors\n\"all\": the class of the point is the same as all the neighbors\n\nmin_ratios=1.0: A parameter that controls the maximum amount of undersampling to be done for each class. If this algorithm cleans the data to an extent that this is violated, some of the cleaned points will be revived randomly so that it is satisfied.\nCan be a float and in this case each class will be at most undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float minimum ratio for that class\n\nforce_min_ratios=false: If true, and this algorithm cleans the data such that the ratios for each class exceed those specified in min_ratios then further undersampling will be performed so that the final ratios are equal to min_ratios.\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n                                min_sep=0.01, stds=[3.0 3.0 3.0], class_probs, rng=42)\n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply enn undersampling\nX_under, y_under = enn_undersample(X, y; k=3, keep_condition=\"only mode\", \n                                   min_ratios=0.5, rng=42)\n\njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 10 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 10 (100.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 24 (240.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ENNUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nENNUndersampler = @load ENNUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = ENNUndersampler(k=3, keep_condition=\"only mode\", min_ratios=0.5, rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is, followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 min_sep=0.01, stds=[3.0 3.0 3.0], class_probs=[0.5, 0.2, 0.3], rng=42)\n\n# Initiate ENN Undersampler model\nundersampler = ENNUndersampler(y_ind; k=3, keep_condition=\"only mode\", rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Dennis L Wilson. Asymptotic properties of nearest neighbor rules using edited data. IEEE Transactions on Systems, Man, and Cybernetics, pages 408–421, 1972.\n\n\n\n\n\n","category":"function"},
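{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"The four keep conditions are easy to spell out in code. keeps below is a hypothetical helper written for illustration only; it is not part of Imbalance.jl and uses countmap from StatsBase:","category":"page"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"using StatsBase: countmap\n\n# Should a point with label l be kept, given its neighbors' labels?\nfunction keeps(l, neighbor_labels, keep_condition)\n    counts = countmap(neighbor_labels)\n    top = maximum(values(counts))\n    modes = [c for (c, n) in counts if n == top]   # most frequent classes\n    keep_condition == \"exists\"    && return l in neighbor_labels\n    keep_condition == \"mode\"      && return l in modes\n    keep_condition == \"only mode\" && return modes == [l]\n    keep_condition == \"all\"       && return all(==(l), neighbor_labels)\n    error(\"unknown keep_condition\")\nend\n\nkeeps(\"a\", [\"a\", \"a\", \"b\"], \"mode\")   # true","category":"page"},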
{"location":"algorithms/undersampling_algorithms/#Tomek-Links-Undersampler","page":"Undersampling","title":"Tomek Links Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"tomek_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.tomek_undersample","page":"Undersampling","title":"Imbalance.tomek_undersample","text":"tomek_undersample(\n    X, y;\n    min_ratios = 1.0, force_min_ratios = false,\n    rng = default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset by removing (\"cleaning\") any point that is part of a Tomek link in the data. Tomek links are presented in [1].\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nmin_ratios=1.0: A parameter that controls the maximum amount of undersampling to be done for each class. If this algorithm cleans the data to an extent that this is violated, some of the cleaned points will be revived randomly so that it is satisfied.\nCan be a float and in this case each class will be at most undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float minimum ratio for that class\n\nforce_min_ratios=false: If true, and this algorithm cleans the data such that the ratios for each class exceed those specified in min_ratios then further undersampling will be performed so that the final ratios are equal to min_ratios.\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n                                min_sep=0.01, stds=[3.0 3.0 3.0], class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply tomek undersampling\nX_under, y_under = tomek_undersample(X, y; min_ratios=1.0, rng=42)\n\njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 22 (115.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 36 (189.5%)\n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the TomekUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = TomekUndersampler(min_ratios=1.0, rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is, followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 min_sep=0.01, stds=[3.0 3.0 3.0], class_probs=[0.5, 0.2, 0.3], rng=42) \n\n# Initiate TomekUndersampler model\nundersampler = TomekUndersampler(y_ind; min_ratios=1.0, rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Ivan Tomek. Two modifications of CNN. IEEE Trans. Systems, Man and Cybernetics, 6:769–772, 1976.\n\n\n\n\n\n","category":"function"},
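{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"The cleaning rule itself is short to state in code. nearest and is_tomek_link below are hypothetical helpers for illustration only (a brute-force sketch, not the package's implementation):","category":"page"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"using LinearAlgebra: norm\n\n# index of the row of X nearest to row i, excluding i itself (brute force)\nnearest(i, X) = argmin([j == i ? Inf : norm(X[i, :] .- X[j, :]) for j in 1:size(X, 1)])\n\n# rows i and j form a Tomek link iff their labels differ and each is\n# the other's nearest neighbor\nis_tomek_link(i, j, X, y) =\n    y[i] != y[j] && nearest(i, X) == j && nearest(j, X) == i","category":"page"},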
{"location":"examples/effect_of_ratios/effect_of_ratios/#Effect-of-ratios-Hyperparameter","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"using Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing ScientificTypes\nusing Imbalance\nusing Plots, Measures","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Loading-Data","page":"Effect of ratios Hyperparameter","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Let's load the Iris dataset; the objective is to predict the type of flower as one of \"virginica\", \"versicolor\" and \"setosa\" using its sepal and petal length and width.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"We don't need to do so from a CSV file this time because MLJ has a macro for loading it already! The only difference is that we will need to explicitly convert it to a dataframe as MLJ loads it as a named tuple of vectors.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"X, y = @load_iris\nX = DataFrame(X)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌──────────────┬─────────────┬──────────────┬─────────────┐\n│ sepal_length │ sepal_width │ petal_length │ petal_width │\n│ Float64 │ Float64 │ Float64 │ Float64 │\n│ Continuous │ Continuous │ Continuous │ Continuous │\n├──────────────┼─────────────┼──────────────┼─────────────┤\n│ 5.1 │ 3.5 │ 1.4 │ 0.2 │\n│ 4.9 │ 3.0 │ 1.4 │ 0.2 │\n│ 4.7 │ 3.2 │ 1.3 │ 0.2 │\n│ 4.6 │ 3.1 │ 1.5 │ 0.2 │\n│ 5.0 │ 3.6 │ 1.4 │ 0.2 │\n└──────────────┴─────────────┴──────────────┴─────────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Our purpose for this tutorial is primarily visualization. Thus, let's select only two of the continuous features to work with. It's known that the petal length and width play a much bigger role in classifying the type of flower, so let's keep only those.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"X = select(X, :petal_width, :petal_length)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌─────────────┬──────────────┐\n│ petal_width │ petal_length │\n│ Float64 │ Float64 │\n│ Continuous │ Continuous │\n├─────────────┼──────────────┤\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.3 │\n│ 0.2 │ 1.5 │\n│ 0.2 │ 1.4 │\n└─────────────┴──────────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Coercing-Data","page":"Effect of ratios Hyperparameter","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"ScientificTypes.schema(X)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌──────────────┬────────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────┼────────────┼─────────┤\n│ petal_width │ Continuous │ Float64 │\n│ petal_length │ Continuous │ Float64 │\n└──────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Things look good, no coercion is needed.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Oversampling","page":"Effect of ratios Hyperparameter","title":"Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Iris, by default, has no imbalance problem","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"checkbalance(y)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"virginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"To simulate that there is a balance problem, we will consider a random sample of 100 observations. A random sample does not guarantee preserving the proportion of classes; in this case, we actually set the seed to get a very unlikely random sample that suffers from moderate imbalance.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Random.seed!(803429)\nsubset_indices = rand(1:size(X, 1), 100)\nX, y = X[subset_indices, :], y[subset_indices]\ncheckbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"versicolor: ▇▇▇▇▇▇▇▇▇▇▇ 12 (22.6%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"We will treat this as our training set going forward so we don't need to partition. Now let's oversample it with SMOTE.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Xover, yover = smote(X, y; k=5, ratios=Dict(\"versicolor\" => 0.7), rng=42)\ncheckbalance(yover)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"setosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 37 (69.8%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Training-the-Model","page":"Effect of ratios Hyperparameter","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... 
)\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Let's go for an SVM","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"import Pkg;\nPkg.add(\"MLJLIBSVMInterface\");","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Before-Oversampling","page":"Effect of ratios Hyperparameter","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"# 1. Load the model\nSVC = @load SVC pkg = LIBSVM\n\n# 2. Instantiate it (γ=0.01 is intentional)\nmodel = SVC(gamma=0.01)\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌ Info: For silent loading, specify `verbosity=0`. \n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\nimport MLJLIBSVMInterface ✔\n\n\n┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @527 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @580 ⏎ AbstractVector{Multiclass{3}}","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#After-Oversampling","page":"Effect of ratios Hyperparameter","title":"After Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @277 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @977 ⏎ AbstractVector{Multiclass{3}}","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Plot-Decision-Boundaries","page":"Effect of ratios Hyperparameter","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Construct ranges for each feature and subsequently a grid","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"petal_width_range =\n\trange(minimum(X.petal_width) - 1, maximum(X.petal_width) + 1, length = 200)\npetal_length_range =\n\trange(minimum(X.petal_length) - 1, maximum(X.petal_length) + 1, length = 200)\ngrid_points = [(pw, pl) for pw in petal_width_range, pl in petal_length_range]","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"200×200 Matrix{Tuple{Float64, Float64}}:\n (-0.9, 0.2) (-0.9, 0.238693) … (-0.9, 7.9)\n (-0.878894, 0.2) (-0.878894, 0.238693) (-0.878894, 7.9)\n (-0.857789, 0.2) (-0.857789, 0.238693) (-0.857789, 7.9)\n (-0.836683, 0.2) (-0.836683, 0.238693) (-0.836683, 7.9)\n (-0.815578, 0.2) (-0.815578, 0.238693) (-0.815578, 7.9)\n (-0.794472, 0.2) (-0.794472, 0.238693) … (-0.794472, 7.9)\n (-0.773367, 0.2) (-0.773367, 0.238693) (-0.773367, 7.9)\n (-0.752261, 0.2) (-0.752261, 0.238693) (-0.752261, 7.9)\n (-0.731156, 0.2) (-0.731156, 0.238693) (-0.731156, 7.9)\n (-0.71005, 0.2) (-0.71005, 0.238693) (-0.71005, 7.9)\n ⋮ ⋱ \n (3.13116, 0.2) (3.13116, 0.238693) (3.13116, 7.9)\n (3.15226, 0.2) (3.15226, 0.238693) (3.15226, 7.9)\n (3.17337, 0.2) (3.17337, 0.238693) (3.17337, 7.9)\n (3.19447, 0.2) (3.19447, 0.238693) (3.19447, 7.9)\n (3.21558, 0.2) (3.21558, 0.238693) … (3.21558, 7.9)\n (3.23668, 0.2) (3.23668, 0.238693) (3.23668, 7.9)\n (3.25779, 0.2) (3.25779, 0.238693) (3.25779, 7.9)\n (3.27889, 0.2) (3.27889, 0.238693) (3.27889, 7.9)\n (3.3, 0.2) (3.3, 0.238693) (3.3, 7.9)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Evaluate the grid with the machine before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"grid_predictions = [\n\tpredict(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]\ngrid_predictions_over = [\n\tpredict(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"200×200 
CategoricalArrays.CategoricalArray{String,2,UInt32}:\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n ⋮ ⋱ \n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Make two contour plots using the grid predictions before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"p = contourf(petal_length_range, petal_width_range, grid_predictions,\n levels=3, color=:Set3_3, colorbar=false)\np_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n levels=3, color=:Set3_3, colorbar=false)\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Scatter plot the data before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"old_count = size(X, 1)\n\ncolors = Dict(\"setosa\" => \"green\", \"versicolor\" => \"yellow\",\n\t\"virginica\" => \"purple\")\nlabels = unique(y)\nfor label in labels\n\tscatter!(p, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"Before Oversampling\")\n\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"After Oversampling\")\n\t# find new points only and plot with different shape\n\tscatter!(p_over, Xover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\tcolor = colors[label], label = label*\"-over\", markershape = :hexagon,\n\t\ttitle = \"After Oversampling\")\nend\n\nplot_res = plot(p, p_over, layout = (1, 2), xlabel = \"petal length\",\n\tylabel = \"petal width\", size = (900, 300), margin = 5mm, dpi = 200)\nsavefig(plot_res, 
\"./assets/before-after-smote.png\")\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"(Image: Before After SMOTE)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Notice how the minority class was completely ignore prior to oversampling. Not all models and hyperparameter settings are this delicate to class imbalance.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Effect-of-Ratios-Hyperparameter","page":"Effect of ratios Hyperparameter","title":"Effect of Ratios Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Now let's study the effect of the ratios hyperparameter. We will do this through an animated plot.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"anim = @animate for versicolor_ratio ∈ 0.3:0.01:2\n\t# oversample\n\tXover, yover =\n\t\tsmote(X, y; k = 5, ratios = Dict(\"versicolor\" => versicolor_ratio), rng = 42)\n\n\t# fit machine\n\tmodel = SVC(gamma = 0.01)\n\tmach_over = machine(model, Xover, yover)\n\tfit!(mach_over, verbosity = 0)\n\n\t# grid predictions\n\tgrid_predictions_over = [\n\t\tpredict(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\tpoint in grid_points\n\t]\n\t# plot\n\tp_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\t\tlevels = 3, color = :Set3_3, colorbar = false)\n\told_count = size(X, 1)\n\tfor label in labels\n\t\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\t\tcolor = colors[label], label = label,\n\t\t\ttitle = \"Oversampling versicolor with ratio $versicolor_ratio\")\n\t\t# find new points only and plot with different shape\n\t\tscatter!(p_over,\n\t\t\tXover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tcolor = colors[label], label = label * \"-over\", markershape = :hexagon,\n\t\t\ttitle = \"Oversampling versicolor with ratio $versicolor_ratio\")\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"gif(anim, \"./assets/smote-animation.gif\", fps=6)\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"(Image: Ratios Parameter Effect)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Notice how setting ratios greedily can lead to overfitting.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#SMOTE-on-Customer-Churn-Data","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"\nimport Pkg;\nPkg.add([\"Random\", \"CSV\", 
\"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing HTTP: download","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Loading-Data","page":"SMOTE on Customer Churn Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"In this example, we will consider the Churn for Bank Customers found on Kaggle where the objective is to predict whether a customer is likely to leave a bank given financial and demographic features.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smote_churn_dataset/churn.csv\", \"./\")\ndf = CSV.read(\"./churn.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌───────────┬────────────┬──────────┬─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ RowNumber │ CustomerId │ Surname │ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ Int64 │ String31 │ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Count │ Textual │ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├───────────┼────────────┼──────────┼─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 1 │ 15634602 │ Hargrave │ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 2 │ 15647311 │ Hill │ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 3 │ 15619304 │ Onio │ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 4 │ 15701354 │ Boni │ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 5 │ 15737888 │ Mitchell │ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 │\n└───────────┴────────────┴──────────┴─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"There are plenty of useless columns that we can get rid of such as RowNumber and CustomerID. 
We also have to get rid of the categorical features because SMOTE won't be able to deal with those; however, other variants such as SMOTE-NC can, as we will see in another tutorial.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"df = df[:, Not([:RowNumber, :CustomerId, :Surname, \n :Geography, :Gender])]\n\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ CreditScore │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├─────────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 619.0 │ 42.0 │ 2.0 │ 0.0 │ 1.0 │ 1.0 │ 1.0 │ 1.01349e5 │ 1.0 │\n│ 608.0 │ 41.0 │ 1.0 │ 83807.9 │ 1.0 │ 0.0 │ 1.0 │ 1.12543e5 │ 0.0 │\n│ 502.0 │ 42.0 │ 8.0 │ 1.59661e5 │ 3.0 │ 1.0 │ 0.0 │ 1.13932e5 │ 1.0 │\n│ 699.0 │ 39.0 │ 1.0 │ 0.0 │ 2.0 │ 0.0 │ 0.0 │ 93826.6 │ 0.0 │\n│ 850.0 │ 43.0 │ 2.0 │ 1.25511e5 │ 1.0 │ 1.0 │ 1.0 │ 79084.1 │ 0.0 │\n└─────────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Ideally, we may even remove ordinal variables because SMOTE will treat them as continuous and the synthetic data it generates will take floating-point values that will not occur in future data. 
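","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"One simple mitigation, sketched below, is to snap the synthetic values of any ordinal columns we keep back to valid integer levels after oversampling. This is a hypothetical post-processing step, not part of the original tutorial; it assumes Xover is the oversampled table produced later on.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# hypothetical post-processing (not part of the tutorial): round synthetic\n# ordinal columns back to integer levels after SMOTE\nfor col in [:Tenure, :NumOfProducts, :HasCrCard, :IsActiveMember]\n    Xover[!, col] = round.(Xover[!, col])\nend","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"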
Some models may nonetheless be robust to this issue, and the main purpose of this tutorial is to later compare SMOTE-NC with SMOTE.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Coercing-Data","page":"SMOTE on Customer Churn Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Let's coerce everything to continuous except for the target variable.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"df = coerce(df, :Age=>Continuous,\n :Tenure=>Continuous,\n :Balance=>Continuous,\n :NumOfProducts=>Continuous,\n :HasCrCard=>Continuous,\n :IsActiveMember=>Continuous,\n :EstimatedSalary=>Continuous,\n :Exited=>Multiclass)\n\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────────┬───────────────┬─────────────────────────────────┐\n│ names │ scitypes │ types │\n├─────────────────┼───────────────┼─────────────────────────────────┤\n│ CreditScore │ Count │ Int64 │\n│ Age │ Continuous │ Float64 │\n│ Tenure │ Continuous │ Float64 │\n│ Balance │ Continuous │ Float64 │\n│ NumOfProducts │ Continuous │ Float64 │\n│ HasCrCard │ Continuous │ Float64 │\n│ IsActiveMember │ Continuous │ Float64 │\n│ EstimatedSalary │ Continuous │ Float64 │\n│ Exited │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└─────────────────┴───────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Unpacking-and-Splitting-Data","page":"SMOTE on Customer Churn Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. 
We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y, X = unpack(df, ==(:Exited); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────┬────────────┬────────────┬────────────┬───────────────┬────────────┬────────────────┬─────────────────┐\n│ CreditScore │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │\n│ Int64 │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │\n│ Count │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │ Continuous │\n├─────────────┼────────────┼────────────┼────────────┼───────────────┼────────────┼────────────────┼─────────────────┤\n│ 669.0 │ 31.0 │ 6.0 │ 1.13001e5 │ 1.0 │ 1.0 │ 0.0 │ 40467.8 │\n│ 822.0 │ 37.0 │ 3.0 │ 105563.0 │ 1.0 │ 1.0 │ 0.0 │ 1.82625e5 │\n│ 423.0 │ 36.0 │ 5.0 │ 97665.6 │ 1.0 │ 1.0 │ 0.0 │ 1.18373e5 │\n│ 623.0 │ 21.0 │ 10.0 │ 0.0 │ 2.0 │ 0.0 │ 1.0 │ 1.35851e5 │\n│ 691.0 │ 37.0 │ 7.0 │ 1.23068e5 │ 1.0 │ 1.0 │ 1.0 │ 98162.4 │\n└─────────────┴────────────┴────────────┴────────────┴───────────────┴────────────┴────────────────┴─────────────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Splitting the data into train and test portions is also easy using MLJ's partition function.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"(CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0 … 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 1, 1, 0, 0, 0 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Oversampling","page":"SMOTE on Customer Churn Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Before deciding to oversample, let's see how severe the imbalance problem is, if it exists. 
Ideally, you may also want to check whether the classification model is robust to this problem.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 2037 (25.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7963 (100.0%)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Looks like we have a class imbalance problem. Let's oversample with SMOTE and set the desired ratios so that the positive minority class is 90% of the majority class.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Xover, yover = smote(X_train, y_train; k=3, ratios=Dict(1=>0.9), rng=42)\ncheckbalance(yover)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 5736 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 6373 (100.0%)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Training-the-Model","page":"SMOTE on Customer Churn Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Because we have scientific types set up, we can easily check what models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"54-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... 
)\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Let's go for a logistic classifier from MLJLinearModels","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"import Pkg; Pkg.add(\"MLJLinearModels\")","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 1. Load the model\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n\n# 2. Instantiate it\nmodel = LogisticClassifier()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train, scitype_check_level=0)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"trained Machine; caches model-specific representations of data\n model: LogisticClassifier(lambda = 2.220446049250313e-16, …)\n args: \n 1:\tSource @113 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n 2:\tSource @972 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Evaluating-the-Model","page":"SMOTE on Customer Churn Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"To evaluate the model, we will use the balanced accuracy metric which equally accounts for all classes. 
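","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"For intuition, balanced accuracy is the unweighted mean of per-class recalls, so a classifier that always predicts the majority class scores only 0.5 on two classes. A minimal sketch in plain Julia (not part of the original tutorial):","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# sketch: balanced accuracy as the unweighted mean of per-class recalls\nrecall(yp, yt, c) = sum((yp .== c) .& (yt .== c)) / sum(yt .== c)\nbalanced_acc(yp, yt) =\n    sum(recall(yp, yt, c) for c in unique(yt)) / length(unique(yt))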
","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling-2","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"0.5","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling-2","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y_pred_over = predict_mode(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"0.57","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Evaluating-the-Model-Revisited","page":"SMOTE on Customer Churn Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy resulting in a 7% improvement. A more precise evaluation would use cross validation to combine many different point estimates into a more precise one (their average). 
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation would imply a smaller confidence interval at the same confidence level.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling-3","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬──────────┬─────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼──────────┼─────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.5 │ 3.29e-16 │ [0.5, 0.5, 0.5 ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴──────────┴─────────────────\n 1 column omitted","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"This looks consistent: negligible standard deviation, with point estimates all centered around 0.5.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling-3","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"At first glance, this seems really nontrivial since resampling will have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"To do this, we must construct the resampling model via its MLJ interface, then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTE(k=3, ratios=Dict(1=>0.9), rng=42)\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = LogisticClassifier(lambda = 2.220446049250313e-16, …), …)\n args: \n 1:\tSource @991 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n 2:\tSource @939 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"true","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.552 │ 0.0145 │ [0.549, 0.563, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"The improvement is about 5.2% after cross-validation. If we further assume the scores to be normally distributed, then the 95% confidence interval on the improvement is 5.2±1.45%. 
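","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"To make that arithmetic concrete, the following sketch (not part of the original tutorial) rederives the interval from the measurement and 1.96*SE columns of the two evaluate! outputs above:","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# sketch: deriving the improvement interval from the evaluate! outputs above\nbefore, after, margin = 0.5, 0.552, 0.0145    # measurements and 1.96*SE\nimprovement = after - before                  # 0.052, i.e. about 5.2%\n(improvement - margin, improvement + margin)  # ≈ (0.0375, 0.0665)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"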
Let's see if this gets any better when we use SMOTE-NC instead, in a later example.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#SMOTENC-on-Customer-Churn-Data","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing HTTP: download","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Loading-Data","page":"SMOTENC on Customer Churn Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"In this example, we will consider the Churn for Bank Customers found on Kaggle where the objective is to predict whether a customer is likely to leave a bank given financial and demographic features. ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We already considered this dataset using SMOTE; in this example, we see if the results are any better using SMOTE-NC.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smotenc_churn_dataset/churn.csv\", \"./\")\ndf = CSV.read(\"./churn.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌───────────┬────────────┬──────────┬─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ RowNumber │ CustomerId │ Surname │ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ Int64 │ String31 │ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Count │ Textual │ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├───────────┼────────────┼──────────┼─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 1 │ 15634602 │ Hargrave │ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 2 │ 15647311 │ Hill │ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 3 │ 15619304 │ Onio │ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 4 │ 15701354 │ Boni │ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 5 │ 15737888 │ Mitchell │ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 
│\n└───────────┴────────────┴──────────┴─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's get rid of useless columns such as RowNumber and CustomerId","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"df = df[:, Not([:Surname, :RowNumber, :CustomerId])]\n\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 │\n└─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Coercing-Data","page":"SMOTENC on Customer Churn Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's coerce the nominal data to Multiclass, the ordinal data to OrderedFactor and the continuous data to Continuous.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"df = coerce(df, \n :Geography => Multiclass, \n :Gender=> Multiclass,\n :CreditScore => OrderedFactor,\n :Age => OrderedFactor,\n :Tenure => OrderedFactor,\n :Balance => Continuous,\n :NumOfProducts => OrderedFactor,\n :HasCrCard => Multiclass,\n :IsActiveMember => Multiclass,\n :EstimatedSalary => Continuous,\n :Exited => Multiclass\n )\n\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────────┬────────────────────┬───────────────────────────────────┐\n│ names │ scitypes │ types │\n├─────────────────┼────────────────────┼───────────────────────────────────┤\n│ CreditScore │ OrderedFactor{460} │ CategoricalValue{Int64, UInt32} │\n│ Geography │ Multiclass{3} │ CategoricalValue{String7, UInt32} │\n│ Gender │ Multiclass{2} │ CategoricalValue{String7, 
UInt32} │\n│ Age │ OrderedFactor{70} │ CategoricalValue{Int64, UInt32} │\n│ Tenure │ OrderedFactor{11} │ CategoricalValue{Int64, UInt32} │\n│ Balance │ Continuous │ Float64 │\n│ NumOfProducts │ OrderedFactor{4} │ CategoricalValue{Int64, UInt32} │\n│ HasCrCard │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ IsActiveMember │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ EstimatedSalary │ Continuous │ Float64 │\n│ Exited │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└─────────────────┴────────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Unpacking-and-Splitting-Data","page":"SMOTENC on Customer Churn Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y, X = unpack(df, ==(:Exited); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬─────────────────────────────────┬─────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬─────────────────────────────────┬─────────────────┐\n│ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │\n│ CategoricalValue{Int64, UInt32} │ CategoricalValue{String7, UInt32} │ CategoricalValue{String7, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ Float64 │\n│ OrderedFactor{460} │ Multiclass{3} │ Multiclass{2} │ OrderedFactor{70} │ OrderedFactor{11} │ Continuous │ OrderedFactor{4} │ Multiclass{2} │ Multiclass{2} │ Continuous │\n├─────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼─────────────────────────────────┼─────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼─────────────────────────────────┼─────────────────┤\n│ 669 │ France │ Female │ 31 │ 6 │ 1.13001e5 │ 1 │ 1 │ 0 │ 40467.8 │\n│ 822 │ France │ Male │ 37 │ 3 │ 105563.0 │ 1 │ 1 │ 0 │ 1.82625e5 │\n│ 423 │ France │ Female │ 36 │ 5 │ 97665.6 │ 1 │ 1 │ 0 │ 1.18373e5 │\n│ 623 │ France │ Male │ 21 │ 10 │ 0.0 │ 2 │ 0 │ 1 │ 1.35851e5 │\n│ 691 │ Germany │ Female │ 37 │ 7 │ 1.23068e5 │ 1 │ 1 │ 1 │ 98162.4 │\n└─────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴─────────────────────────────────┴─────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴─────────────────────────────────┴─────────────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, \n rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer 
Churn Data","title":"SMOTENC on Customer Churn Data","text":"(CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0 … 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 1, 1, 0, 0, 0 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Oversampling","page":"SMOTENC on Customer Churn Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Before deciding to oversample, let's see how adverse is the imbalance problem, if it exists. Ideally, you may as well check if the classification model is robust to this problem.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 2037 (25.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7963 (100.0%)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Looks like we have a class imbalance problem. Let's oversample with SMOTE-NC and set the desired ratios so that the positive minority class is 90% of the majority class","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Xover, yover = smotenc(X_train, y_train; k=3, ratios=Dict(1=>0.9), rng=42)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"(12109×10 DataFrame\n Row │ CreditScore Geography Gender Age Tenure Balance NumOfPr ⋯\n │ Cat… Cat… Cat… Cat… Cat… Float64 Cat… ⋯\n───────┼────────────────────────────────────────────────────────────────────────\n 1 │ 551 France Female 38 10 0.0 2 ⋯\n 2 │ 676 France Female 37 5 89634.7 1\n 3 │ 543 France Male 42 4 89838.7 3\n 4 │ 663 France Male 34 10 0.0 1\n 5 │ 621 Germany Female 34 2 91258.5 2 ⋯\n 6 │ 723 France Male 28 4 0.0 2\n 7 │ 735 France Female 21 1 1.78718e5 2\n 8 │ 501 France Male 35 6 99760.8 1\n ⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱\n 12103 │ 551 France Female 40 2 1.68002e5 1 ⋯\n 12104 │ 716 France Female 46 2 1.09379e5 2\n 12105 │ 850 Spain Female 45 10 1.66777e5 1\n 12106 │ 785 France Female 39 9 1.33118e5 1\n 12107 │ 565 Germany Female 39 5 1.44874e5 1 ⋯\n 12108 │ 510 Germany Male 43 0 1.38862e5 1\n 12109 │ 760 France Female 41 2 113419.0 1\n 4 columns and 12094 rows omitted, CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0 … 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"checkbalance(yover)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 5736 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 6373 
(100.0%)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Training-the-Model","page":"SMOTENC on Customer Churn Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's find possible models","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"ms = models(matching(Xover, yover))","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"5-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's go for a decision tree classifier from BetaML.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's go for a decision tree from BetaML. We can't go for logistic regression as we did in the SMOTE tutorial because it does not support categotical features.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier( max_depth=4, rng=Random.Xoshiro(42))\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import BetaML ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 4, …)\n args: \n 1:\tSource @378 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @049 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 4, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 4, …)\n args: \n 1:\tSource @033 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @939 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Evaluating-the-Model","page":"SMOTENC on Customer Churn Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"To evaluate the model, we will use the balanced accuracy metric which equally accounts for all classes. 
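","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"As a quick sanity check on why plain accuracy would mislead here, recall the class counts reported by checkbalance (7963 versus 2037): a model that always predicts the majority class looks accurate yet has a balanced accuracy of only 0.5. A small sketch (not part of the original tutorial):","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# sketch: plain accuracy vs balanced accuracy for an always-majority model\nn0, n1 = 7963, 2037         # class counts from checkbalance(y)\naccuracy = n0 / (n0 + n1)   # ≈ 0.796, looks deceptively good\nbalanced = (1.0 + 0.0) / 2  # recall(0) = 1, recall(1) = 0, so 0.5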
","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling-2","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"0.57","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling-2","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y_pred_over = predict_mode(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"0.7","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Although the results do get better compared to when we just used SMOTE, it may hold in this case that the extra categorical features we took into account are not be that important. The difference may be attributed to the decision tree.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Evaluating-the-Model-Revisited","page":"SMOTENC on Customer Churn Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy resulting in a 13% improvement. A more precise evaluation would use cross validation to combine many different point estimates into a more precise one (their average). 
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation would imply a smaller confidence interval at the same confidence level.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling-3","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:02:54\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.565 │ 0.00623 │ [0.568, 0.554, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Before oversampling, and assuming that the balanced accuracy score is normally distributed, we can be 95% confident that the balanced accuracy on new data is 56.5±0.62%. Indeed, this agrees closely with the original point estimate.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling-3","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"At first glance, this seems really nontrivial since resampling will have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"To do this, we must construct the resampling model via its MLJ interface, then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTENC(k=3, ratios=Dict(1=>0.9), rng=42)\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = DecisionTreeClassifier(max_depth = 4, …), …)\n args: \n 1:\tSource @967 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @394 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"true","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:07:24\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.677 │ 0.0124 │ [0.678, 0.688, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Fair enough. 
After oversampling, the interval under the same assumptions is 67.7±1.2%, which is still a meaningful improvement over the 56.5±0.62% that we had prior to oversampling or the 55.2±1.5% that we had with logistic regression and SMOTE in an earlier example.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Balanced-Bagging-for-Cerebral-Stroke-Prediction","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n    \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing StatsBase\nusing ScientificTypes\nusing Plots, Measures\nusing Impute\nusing HTTP: download","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Loading-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Loading Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"In this example, we will consider the Cerebral Stroke Prediction Dataset found on Kaggle, with the objective of predicting whether a stroke has occurred given medical features about patients.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/cerebral_ensemble/cerebral.csv\", \"./\")\ndf = CSV.read(\"./cerebral.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────┬─────────┬────────────┬──────────────┬───────────────┬──────────────┬──────────────┬────────────────┬───────────────────┬────────────────────────────┬──────────────────────────┬────────┐\n│ id │ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ Int64 │ String7 │ Float64 │ Int64 │ Int64 │ String3 │ String15 │ String7 │ Float64 │ Union{Missing, Float64} │ Union{Missing, String15} │ Int64 │\n│ Count │ Textual │ Continuous │ Count │ Count │ Textual │ Textual │ Textual │ Continuous │ Union{Missing, Continuous} │ Union{Missing, Textual} │ Count │\n├───────┼─────────┼────────────┼──────────────┼───────────────┼──────────────┼──────────────┼────────────────┼───────────────────┼────────────────────────────┼──────────────────────────┼────────┤\n│ 30669 │ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ missing │ 0 │\n│ 30468 │ Male │ 
58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ 16523 │ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ missing │ 0 │\n│ 56543 │ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ 46136 │ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ missing │ 0 │\n└───────┴─────────┴────────────┴──────────────┴───────────────┴──────────────┴──────────────┴────────────────┴───────────────────┴────────────────────────────┴──────────────────────────┴────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"It's obvious that the id column is useless for predictions so we may as well drop it.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = df[:, Not(:id)]\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌─────────┬────────────┬──────────────┬───────────────┬──────────────┬──────────────┬────────────────┬───────────────────┬────────────────────────────┬──────────────────────────┬────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ String7 │ Float64 │ Int64 │ Int64 │ String3 │ String15 │ String7 │ Float64 │ Union{Missing, Float64} │ Union{Missing, String15} │ Int64 │\n│ Textual │ Continuous │ Count │ Count │ Textual │ Textual │ Textual │ Continuous │ Union{Missing, Continuous} │ Union{Missing, Textual} │ Count │\n├─────────┼────────────┼──────────────┼───────────────┼──────────────┼──────────────┼────────────────┼───────────────────┼────────────────────────────┼──────────────────────────┼────────┤\n│ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ missing │ 0 │\n│ Male │ 58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ missing │ 0 │\n│ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ missing │ 0 │\n└─────────┴────────────┴──────────────┴───────────────┴──────────────┴──────────────┴────────────────┴───────────────────┴────────────────────────────┴──────────────────────────┴────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Visualize-the-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Visualize the Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Since this dataset is composed mostly of categorical features, a bar chart for each categorical column is a good way to visualize the data.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# Create a bar chart for each column\nbar_charts = []\nfor col in names(df)\n counts = countmap(df[!, col])\n k, v = 
collect(keys(counts)), collect(values(counts))\n    if length(k) < 20\n        push!(bar_charts, bar(k, v, legend=false, title=col, color=\"turquoise3\", xrotation=90, margin=6mm))\n    end\nend\n\n# Combine bar charts into a grid layout with specified plot size\nplot_res = plot(bar_charts..., layout=(3, 4),\n    size=(1300, 500),\n    dpi=200\n    )\nsavefig(plot_res, \"./assets/cerebral-charts.png\")\n","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"(Image: Cerebral Stroke Features Plots)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Our target here is the stroke variable; notice how imbalanced it is.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Coercing-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Coercing Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. It's often necessary to coerce the types found by default to the appropriate type.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────┬────────────────────────────┬──────────────────────────┐\n│ names │ scitypes │ types │\n├───────────────────┼────────────────────────────┼──────────────────────────┤\n│ gender │ Textual │ String7 │\n│ age │ Continuous │ Float64 │\n│ hypertension │ Count │ Int64 │\n│ heart_disease │ Count │ Int64 │\n│ ever_married │ Textual │ String3 │\n│ work_type │ Textual │ String15 │\n│ Residence_type │ Textual │ String7 │\n│ avg_glucose_level │ Continuous │ Float64 │\n│ bmi │ Union{Missing, Continuous} │ Union{Missing, Float64} │\n│ smoking_status │ Union{Missing, Textual} │ Union{Missing, String15} │\n│ stroke │ Count │ Int64 │\n└───────────────────┴────────────────────────────┴──────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"For instance, here we need to coerce all the columns to Multiclass, as they are all nominal variables, except for age, avg_glucose_level and bmi, which we can treat as continuous","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = coerce(df, :gender => Multiclass, :age => Continuous, :hypertension => Multiclass,\n\t:heart_disease => Multiclass, :ever_married => Multiclass, :work_type => Multiclass,\n\t:Residence_type => Multiclass, :avg_glucose_level => Continuous,\n\t:bmi => Continuous, :smoking_status => Multiclass, :stroke => 
Multiclass,\n)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────┬───────────────┬────────────────────────────────────┐\n│ names │ scitypes │ types │\n├───────────────────┼───────────────┼────────────────────────────────────┤\n│ gender │ Multiclass{3} │ CategoricalValue{String7, UInt32} │\n│ age │ Continuous │ Float64 │\n│ hypertension │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ heart_disease │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ ever_married │ Multiclass{2} │ CategoricalValue{String3, UInt32} │\n│ work_type │ Multiclass{5} │ CategoricalValue{String15, UInt32} │\n│ Residence_type │ Multiclass{2} │ CategoricalValue{String7, UInt32} │\n│ avg_glucose_level │ Continuous │ Float64 │\n│ bmi │ Continuous │ Float64 │\n│ smoking_status │ Multiclass{3} │ CategoricalValue{String15, UInt32} │\n│ stroke │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└───────────────────┴───────────────┴────────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"As shown in the types, some columns have missing values; we will impute them using simple random sampling, as dropping their rows would mean losing a big chunk of the dataset.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = Impute.srs(df); disallowmissing!(df)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬───────────────────────────────────┬────────────────────────────────────┬───────────────────────────────────┬───────────────────┬────────────┬────────────────────────────────────┬─────────────────────────────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ CategoricalValue{String7, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{String3, UInt32} │ CategoricalValue{String15, UInt32} │ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │ CategoricalValue{String15, UInt32} │ CategoricalValue{Int64, UInt32} │\n│ Multiclass{3} │ Continuous │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{5} │ Multiclass{2} │ Continuous │ Continuous │ Multiclass{3} │ Multiclass{2} │\n├───────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼───────────────────────────────────┼────────────────────────────────────┼───────────────────────────────────┼───────────────────┼────────────┼────────────────────────────────────┼─────────────────────────────────┤\n│ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ formerly smoked │ 0 │\n│ Male │ 58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ never 
smoked │ 0 │\n│ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ formerly smoked │ 0 │\n└───────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴───────────────────────────────────┴────────────────────────────────────┴───────────────────────────────────┴───────────────────┴────────────┴────────────────────────────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Unpacking-and-Splitting-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y, X = unpack(df, ==(:stroke); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬───────────────────────────────────┬────────────────────────────────────┬───────────────────────────────────┬───────────────────┬────────────┬────────────────────────────────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │\n│ CategoricalValue{String7, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{String3, UInt32} │ CategoricalValue{String15, UInt32} │ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │ CategoricalValue{String15, UInt32} │\n│ Multiclass{3} │ Continuous │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{5} │ Multiclass{2} │ Continuous │ Continuous │ Multiclass{3} │\n├───────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼───────────────────────────────────┼────────────────────────────────────┼───────────────────────────────────┼───────────────────┼────────────┼────────────────────────────────────┤\n│ Female │ 37.0 │ 0 │ 0 │ Yes │ Private │ Urban │ 103.66 │ 36.1 │ smokes │\n│ Female │ 78.0 │ 0 │ 0 │ No │ Private │ Rural │ 83.97 │ 39.6 │ formerly smoked │\n│ Female │ 2.0 │ 0 │ 0 │ No │ children │ Urban │ 98.66 │ 17.0 │ smokes │\n│ Female │ 62.0 │ 0 │ 0 │ No │ Private │ Rural │ 205.41 │ 27.8 │ smokes │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Private │ Rural │ 118.18 │ 24.5 │ never smoked │\n└───────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴───────────────────────────────────┴────────────────────────────────────┴───────────────────────────────────┴───────────────────┴────────────┴────────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for 
Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits which is more representative of the real world.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(42)\n)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations then train-test contamination has occurred; novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Oversampling","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Oversampling","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"It was obvious from the bar charts that there is a severe imbalance problem. Let's look at that again.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"1: ▇ 783 (1.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 42617 (100.0%)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Indeed, may be too severe for most models.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Training-the-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Training the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Because we have scientific types setup, we can easily check what models will be able to train on our data. 
{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Training-the-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Training the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Because we have scientific types set up, we can easily check which models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"ms = models(matching(X, y))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"6-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = OneRuleClassifier, package_name = OneRule, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Let's go for a DecisionTreeClassifier","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"   Resolving package versions...\n   Installed MLJBalancing ─ v0.1.0\n    Updating `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n  [45f359ea] + MLJBalancing v0.1.0\n    Updating `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`\n  [45f359ea] + MLJBalancing v0.1.0\nPrecompiling project...\n  ✓ MLJBalancing\n  1 dependency successfully precompiled in 25 seconds. 262 already precompiled.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Load-and-Construct","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Load and Construct","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier(max_depth=4)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import BetaML ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\nDecisionTreeClassifier(\n  max_depth = 4, \n  min_gain = 0.0, \n  min_records = 2, \n  max_features = 0, \n  splitting_criterion = BetaML.Utils.gini, \n  rng = Random._GLOBAL_RNG())","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Wrap-in-a-machine-and-fit!","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Wrap in a machine and fit!","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"trained Machine; caches model-specific representations of data\n  model: DecisionTreeClassifier(max_depth = 4, …)\n  args: \n    1:\tSource @245 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{3}}}}\n    2:\tSource @251 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Evaluate-the-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Evaluate the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y_pred = MLJ.predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"0.5","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Training-BalancedBagging-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Training BalancedBagging Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"The results suggest that the model is just as good as random guessing. Let's see if this gets better by using a BalancedBaggingClassifier. This classifier trains T copies of the given model on T undersampled versions of the dataset, where each undersampled version has as many majority examples as minority examples.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"This approach allows us to work around the imbalance issue without losing any data. 
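To make the idea concrete, here is a rough sketch of how such bags could be built at the index level; balanced_bags is a hypothetical helper for illustration only, since MLJBalancing performs the bagging internally:\n\nusing Random, StatsBase\n\n# Hypothetical helper: build T balanced bags of row indices, each keeping all\n# minority rows plus an equally-sized random subset of majority rows\nfunction balanced_bags(y, T; rng=Random.Xoshiro(42))\n    min_idx = findall(==(1), y)\n    maj_idx = findall(==(0), y)\n    return [vcat(min_idx, sample(rng, maj_idx, length(min_idx); replace=false))\n            for _ in 1:T]\nend\n\ny_demo = vcat(ones(Int, 10), zeros(Int, 90))  # toy imbalanced target\nbags = balanced_bags(y_demo, 5)               # 5 bags of 20 indices each\n\n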
For instance, if we set T=round(Int, 100/1.8) (which is the default) then, on average, every majority example will be used in one of the T bags.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Load-and-Construct-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Load and Construct","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"bagging_model = BalancedBaggingClassifier(model=model, T=30, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"BalancedBaggingClassifier(\n  model = DecisionTreeClassifier(\n        max_depth = 4, \n        min_gain = 0.0, \n        min_records = 2, \n        max_features = 0, \n        splitting_criterion = BetaML.Utils.gini, \n        rng = Random._GLOBAL_RNG()), \n  T = 30, \n  rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Wrap-in-a-machine-and-fit!-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Wrap in a machine and fit!","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(bagging_model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"trained Machine; does not cache data\n  model: BalancedBaggingClassifier(model = DecisionTreeClassifier(max_depth = 4, …), …)\n  args: \n    1:\tSource @005 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{3}}}}\n    2:\tSource @531 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Evaluate-the-Model-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Evaluate the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y_pred = MLJ.predict_mode(mach_over, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"0.77","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"This is a dramatic improvement over what we had before. 
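As a side note on why 0.5 was the floor for the plain model: balanced accuracy averages per-class recalls, so a classifier that always predicts the majority class scores exactly 0.5. A toy check, assuming MLJ and CategoricalArrays are loaded:\n\nusing MLJ, CategoricalArrays\n\ny_true  = categorical(vcat(fill(0, 95), fill(1, 5)))  # imbalanced ground truth\ny_naive = categorical(fill(0, 100), levels=[0, 1])    # always predict majority\nbalanced_accuracy(y_naive, y_true)                    # 0.5\n\n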
Let's confirm with cross-validation.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy, operation=predict_mode) ","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:01:40\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.772 │ 0.0146 │ [0.738, 0.769, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Assuming normality of the scores, the 95% confidence interval is 77.2±1.4% for the balanced accuracy.","category":"page"},{"location":"examples/","page":"More Examples","title":"More Examples","text":"
\n\n","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#SMOTE-Tomek-for-Ethereum-Fraud-Detection","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing Impute\nusing HTTP: download","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Loading-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Loading Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"In this example, we will consider the Ethereum Fraud Detection Dataset found on Kaggle where the objective is to predict whether an Ethereum transaction is fraud or not (called FLAG) given some features about the transaction.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/fraud_detection/transactions.csv\", \"./\")\n\ndf = CSV.read(\"./transactions.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"There are plenty of useless columns that we can get rid of such as Column1, Index and probably, Address. We also have to get rid of the categorical features because SMOTE won't be able to deal with those and it leaves us with more options for the model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = df[:,\n\tNot([\n\t\t:Column1,\n\t\t:Index,\n\t\t:Address,\n\t\tSymbol(\" ERC20 most sent token type\"),\n\t\tSymbol(\" ERC20_most_rec_token_type\"),\n\t]),\n] \nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"If you scroll through the printed data frame, you find that some columns also have Missing for their element type, meaning that they may be containing missing values. We will use linear interpolation, last-observation carried forward and next observation carried backward techniques to fill up the missing values. 
This will allow us to call disallowmissing!(df) to return a dataframe where Missing is not an element type for any column.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = Impute.interp(df) |> Impute.locf() |> Impute.nocb(); disallowmissing!(df)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Coercing-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Coercing Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Let's look at the schema first","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"┌──────────────────────────────────────────────────────┬────────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────────────────────────────────────────────┼────────────┼─────────┤\n│ FLAG │ Count │ Int64 │\n│ Avg min between sent tnx │ Continuous │ Float64 │\n│ Avg min between received tnx │ Continuous │ Float64 │\n│ Time Diff between first and last (Mins) │ Continuous │ Float64 │\n│ Sent tnx │ Count │ Int64 │\n│ Received Tnx │ Count │ Int64 │\n│ Number of Created Contracts │ Count │ Int64 │\n│ Unique Received From Addresses │ Count │ Int64 │\n│ Unique Sent To Addresses │ Count │ Int64 │\n│ min value received │ Continuous │ Float64 │\n│ max value received │ Continuous │ Float64 │\n│ avg val received │ Continuous │ Float64 │\n│ min val sent │ Continuous │ Float64 │\n│ max val sent │ Continuous │ Float64 │\n│ avg val sent │ Continuous │ Float64 │\n│ min value sent to contract │ Continuous │ Float64 │\n│ ⋮ │ ⋮ │ ⋮ │\n└──────────────────────────────────────────────────────┴────────────┴─────────┘\n 30 rows omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"The FLAG target should definitely be Multiclass; the rest seems fine.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = coerce(df, :FLAG => Multiclass)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"┌──────────────────────────────────────────────────────┬───────────────┬────────\n│ names │ scitypes │ types ⋯\n├──────────────────────────────────────────────────────┼───────────────┼────────\n│ FLAG │ Multiclass{2} │ Categ ⋯\n│ Avg min between sent tnx │ Continuous │ Float ⋯\n│ Avg min between received tnx │ Continuous │ Float ⋯\n│ Time Diff between first and last (Mins) │ Continuous │ Float ⋯\n│ Sent tnx │ Count │ Int64 ⋯\n│ Received Tnx │ Count │ Int64 ⋯\n│ Number of Created Contracts │ Count │ Int64 ⋯\n│ Unique Received From Addresses │ Count │ Int64 ⋯\n│ Unique Sent To Addresses │ Count │ Int64 ⋯\n│ min value 
received │ Continuous │ Float ⋯\n│ max value received │ Continuous │ Float ⋯\n│ avg val received │ Continuous │ Float ⋯\n│ min val sent │ Continuous │ Float ⋯\n│ max val sent │ Continuous │ Float ⋯\n│ avg val sent │ Continuous │ Float ⋯\n│ min value sent to contract │ Continuous │ Float ⋯\n│ ⋮ │ ⋮ │ ⋱\n└──────────────────────────────────────────────────────┴───────────────┴────────\n 1 column and 30 rows omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Unpacking-and-Splitting-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"y, X = unpack(df, ==(:FLAG); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Splitting the data into train and test portions is also easy using MLJ's partition function.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(41)\n)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Resampling","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Resampling","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Before deciding to oversample, let's see how adverse the imbalance problem is, if it exists. Ideally, you may as well check whether the classification model is robust to this problem.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"This signals a potential class imbalance problem. Let's consider using SMOTE-Tomek to resample this data. The SMOTE-Tomek algorithm is nothing but SMOTE followed by TomekUndersampler. We can wrap these in a pipeline along with a classification model for predictions using BalancedModel from MLJBalancing. 
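Before constructing the resamplers, recall from the implementation notes what the ratio keywords mean: for oversamplers, ratios is relative to the majority class, while for undersamplers, min_ratios is relative to the minority class. A toy computation with hypothetical counts:\n\n# hypothetical class counts, for illustration only\nn_majority = 7000  # e.g., class 0\nn_minority = 1000  # e.g., class 1\n\n# SMOTE with ratios=Dict(1 => 0.5): class 1 is oversampled up to\n# 0.5 × the majority count\ntarget_minority = round(Int, 0.5 * n_majority)       # 3500\n\n# TomekUndersampler with min_ratios=Dict(0 => 1.3) and force_min_ratios=true:\n# class 0 is undersampled down toward 1.3 × the new minority count\ntarget_majority = round(Int, 1.3 * target_minority)  # 4550\n\n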
Let's go for a RandomForestClassifier from DecisionTree.jl for the model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"import Pkg; Pkg.add(\"DecisionTree\")","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Construct-the-Resampling-and-Classification-Models","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Construct the Resampling & Classification Models","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"oversampler = Imbalance.MLJ.SMOTE(ratios=Dict(1=>0.5), rng=Random.Xoshiro(42))\nundersampler = Imbalance.MLJ.TomekUndersampler(min_ratios=Dict(0=>1.3), force_min_ratios=true)\nRandomForestClassifier = @load RandomForestClassifier pkg=DecisionTree\nmodel = RandomForestClassifier(n_trees=2, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"RandomForestClassifier(\n max_depth = -1, \n min_samples_leaf = 1, \n min_samples_split = 2, \n min_purity_increase = 0.0, \n n_subfeatures = -1, \n n_trees = 2, \n sampling_fraction = 0.7, \n feature_importance = :impurity, \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Form-the-Pipeline-using-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Form the Pipeline using BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"balanced_model = BalancedModel(model=model, balancer1=oversampler, balancer2=undersampler)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"BalancedModelProbabilistic(\n model = RandomForestClassifier(\n max_depth = -1, \n min_samples_leaf = 1, \n min_samples_split = 2, \n min_purity_increase = 0.0, \n n_subfeatures = -1, \n n_trees = 2, \n sampling_fraction = 0.7, \n feature_importance = :impurity, \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1)), \n balancer1 = SMOTE(\n k = 5, \n ratios = Dict(1 => 0.5), \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1), \n try_preserve_type = true), \n balancer2 = TomekUndersampler(\n min_ratios = Dict(0 => 1.3), \n force_min_ratios = true, \n rng = TaskLocalRNG(), \n try_preserve_type = true))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Now we can treat balanced_model like any MLJ model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Fit-the-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Fit the BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"# 3. 
Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = RandomForestClassifier(max_depth = -1, …), …)\n args: \n 1:\tSource @967 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n 2:\tSource @913 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Validate-the-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Validate the BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.93 │ 0.00757 │ [0.927, 0.936, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Compare-with-RandomForestClassifier-only","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Compare with RandomForestClassifier only","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"To see if this represents any form of improvement, fitting and validating the original model by itself.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"# 3. 
Wrap it with the data in a machine\nmach = machine(model, X_train, y_train, scitype_check_level=0)\nfit!(mach)\n\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.908 │ 0.00932 │ [0.903, 0.898, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Assuming normal scores, the 95% confidence interval was 90.8±0.9% and after resampling it became 93.0±0.8%, which corresponds to a small improvement in balanced accuracy.","category":"page"},
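{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"A quick way to gauge that improvement is to check whether the two 95% intervals overlap; a sketch with the endpoints reported above:\n\n# interval endpoints from the two evaluations above (in percent)\nbefore = (90.8 - 0.9, 90.8 + 0.9)  # (89.9, 91.7)\nafter  = (93.0 - 0.8, 93.0 + 0.8)  # (92.2, 93.8)\n\noverlap = before[2] >= after[1]    # false: the intervals do not overlap","category":"page"},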
{"location":"examples/Colab/#Google-Colab","page":"Google Colab","title":"Google Colab","text":"","category":"section"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"It is possible to run the tutorials found in the examples section or the API documentation on Google Colab (using the provided link or icon). Launching the notebook should make the procedure evident. This section describes what happens under the hood.","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"The first cell runs the following bash script to install Julia:","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"%%capture\n%%shell\nif ! command -v julia > /dev/null 2>&1\nthen\n    wget -q 'https://julialang-s3.julialang.org/bin/linux/x64/1.7/julia-1.7.2-linux-x86_64.tar.gz' \\\n        -O /tmp/julia.tar.gz\n    tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1\n    rm /tmp/julia.tar.gz\nfi\njulia -e 'using Pkg; pkg\"add IJulia; precompile;\"'\necho 'Done'","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"Once that is done, one can change the runtime to Julia by choosing Runtime from the toolbar, then Change runtime type; at this point, the installation cell can be deleted","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"Sincere thanks to Julia-on-Colab for making this possible.","category":"page"},{"location":"#Imbalance.jl","page":"Introduction","title":"Imbalance.jl","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"(Image: Imbalance)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"A Julia package with resampling methods to correct for class imbalance in a wide variety of classification settings.","category":"page"},{"location":"#Installation","page":"Introduction","title":"Installation","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"import Pkg;\nPkg.add(\"Imbalance\")","category":"page"},{"location":"#Implemented-Methods","page":"Introduction","title":"Implemented Methods","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"The package implements the following resampling algorithms","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Random Oversampling\nRandom Walk Oversampling (RWO)\nRandom Oversampling Examples (ROSE)\nSynthetic Minority Oversampling Technique (SMOTE)\nBorderline SMOTE1\nSMOTE-Nominal (SMOTE-N)\nSMOTE-Nominal Continuous (SMOTE-NC)\nRandom Undersampling\nCluster Undersampling\nEditedNearestNeighbors Undersampling\nTomek Links Undersampling\nBalanced Bagging Classifier (@MLJBalancing.jl)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"To see various examples where such methods help improve classification performance, check the tutorials sections of the documentation.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Interested in contributing more? 
Check this.","category":"page"},{"location":"#Quick-Start","page":"Introduction","title":"Quick Start","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"We will illustrate using the package to oversample with SMOTE; however, all other implemented oversampling methods follow the same pattern.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Let's start by generating some dummy imbalanced data:","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\n\n# Set dataset properties then generate imbalanced data\nclass_probs = [0.5, 0.2, 0.3]         # probability of each class \nnum_rows, num_continuous_feats = 100, 5\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; class_probs, rng=42)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"In the following code blocks, it will be assumed that X and y are readily available.","category":"page"},{"location":"#Standard-API","page":"Introduction","title":"Standard API","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"All methods by default support a pure functional interface.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\n\n# Apply SMOTE to oversample the classes\nXover, yover = smote(X, y; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)","category":"page"},{"location":"#MLJ-Interface","page":"Introduction","title":"MLJ Interface","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"All methods support the MLJ interface, where instead of directly calling the method, one instantiates a model for it (optionally passing the keyword parameters found in the functional interface), wraps the model in a machine, and then calls transform on the machine and the data.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using MLJ\n\n# Load the model\nSMOTE = @load SMOTE pkg=Imbalance\n\n# Create an instance of the model \noversampler = SMOTE(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\n# Wrap it in a machine\nmach = machine(oversampler)\n\n# Provide the data to transform \nXover, yover = transform(mach, X, y)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"All implemented oversampling methods are considered static transforms and hence no fit is required.","category":"page"},
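{"location":"","page":"Introduction","title":"Introduction","text":"To verify the effect of the transform, one can compare the class counts before and after; a minimal sketch using checkbalance from Imbalance, assuming X and y from above:\n\nusing Imbalance\n\ncheckbalance(y)      # class counts before oversampling\ncheckbalance(yover)  # counts after; should reflect the requested ratios","category":"page"},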
","category":"page"},{"location":"#Pipelining-Models","page":"Introduction","title":"Pipelining Models","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"If MLJBalancing is also used, an arbitrary number of resampling methods from Imbalance.jl can be wrapped with a classification model from MLJ to function as a unified model where resampling automatically takes place on given data before training the model (and is bypassed during prediction).","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using MLJ, MLJBalancing\n\n# grab two resamplers and a classifier\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\nSMOTE = @load SMOTE pkg=Imbalance verbosity=0\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0\n\noversampler = SMOTE(k=5, ratios=1.0, rng=42)\nundersampler = TomekUndersampler(min_ratios=0.5, rng=42)\nlogistic_model = LogisticClassifier()\n\n# wrap the oversampler, undersample and classification model together\nbalanced_model = BalancedModel(model=logistic_model, \n balancer1=oversampler, balancer2=undersampler)\n\n# behaves like a single model\nmach = machine(balanced_model, X, y);\nfit!(mach, verbosity=0)\npredict(mach, X)","category":"page"},{"location":"#Table-Transforms-Interface","page":"Introduction","title":"Table Transforms Interface","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"The TableTransforms interface operates on single tables; it assumes that y is one of the columns of the given table. Thus, it follows a similar pattern to the MLJ interface except that the index of y is a required argument while instantiating the model and the data to be transformed via apply is only one table Xy.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\nusing Imbalance.TableTransforms\nusing TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate SMOTE model\noversampler = SMOTE(y_ind; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler # can chain with other table transforms \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"The reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Notice that because the interfaces of MLJ and TableTransforms use the same model names, you will have to specify the source of the model if both are used in the same file (e.g., Imbalance.TableTransforms.SMOTE) for the example above.","category":"page"},{"location":"#Features","page":"Introduction","title":"Features","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"Supports multi-class variants of the algorithms and both nominal and continuous features\nProvides MLJ and TableTransforms interfaces aside from the default pure functional interface\nGeneric by supporting table input/output formats as well as matrices\nSupports tables 
regardless of whether the target is a separate column or one of the columns\nSupports automatic encoding and decoding of nominal features","category":"page"},{"location":"#Rationale","page":"Introduction","title":"Rationale","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"Most if not all machine learning algorithms can be viewed as a form of empirical risk minimization, where the objective is to find the parameters θ that, for some loss function L, minimize","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"\\hat{\\theta} = \\arg\\min_{\\theta} \\frac{1}{N} \\sum_{i=1}^{N} L(f_{\\theta}(x_i), y_i)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"The underlying assumption is that minimizing this empirical risk corresponds to approximately minimizing the true risk, which considers all examples in the population; this would imply that f_θ is approximately the true target function f that we seek to model.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"In a multi-class setting with K classes, one can write","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"\\hat{\\theta} = \\arg\\min_{\\theta} \\left( \\frac{1}{N_1} \\sum_{i \\in C_1} L(f_{\\theta}(x_i), y_i) + \\frac{1}{N_2} \\sum_{i \\in C_2} L(f_{\\theta}(x_i), y_i) + \\ldots + \\frac{1}{N_K} \\sum_{i \\in C_K} L(f_{\\theta}(x_i), y_i) \\right)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Class imbalance occurs when some classes have far fewer examples than others. In this case, the terms corresponding to the smaller classes contribute minimally to the sum, which makes it possible for a learning algorithm to find an approximate solution to minimizing the empirical risk that mostly minimizes only the significant sums. This yields a hypothesis f_θ that may be very different from the true target f with respect to the minority classes, which may be the most important ones for the application in question.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"One obvious remedy is to weight the smaller sums so that a learning algorithm more easily avoids approximate solutions that exploit their insignificance; this can be seen to be equivalent to repeating the observations of the minority classes. This can be achieved by naive random oversampling, which this package offers along with more advanced resampling methods that generate synthetic data or delete existing observations. You can read more about the class imbalance problem and learn about the various algorithms implemented in this package by reading this series of articles on Medium.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"To our knowledge, there are no existing maintained Julia packages that implement resampling algorithms for multi-class classification problems or that handle both nominal and continuous features. 
This has served as a primary motivation for the creation of this package.","category":"page"}] +[{"location":"algorithms/extra_algorithms/#Extras","page":"Extras","title":"Extras","text":"","category":"section"},{"location":"algorithms/extra_algorithms/#Generate-Imbalanced-Data","page":"Extras","title":"Generate Imbalanced Data","text":"","category":"section"},{"location":"algorithms/extra_algorithms/","page":"Extras","title":"Extras","text":"generate_imbalanced_data","category":"page"},{"location":"algorithms/extra_algorithms/#Imbalance.generate_imbalanced_data","page":"Extras","title":"Imbalance.generate_imbalanced_data","text":"generate_imbalanced_data(\n num_rows, num_continuous_feats;\n means=nothing, min_sep=1.0, stds=nothing,\n num_vals_per_category = [],\n class_probs = [0.8, 0.2],\n type= \"ColTable\", insert_y= nothing,\n rng= default_rng(),\n)\n\nGenerate num_rows observations with target y respecting given probabilities of each class. Supports generating continuous features with a specific mean and variance and categorical features given the number of levels in each variable.\n\nArguments\n\nnum_rows::Integer: Number of observations to generate\nnum_continuous_feats::Integer: Number of continuous features to generate\nmeans::AbstractVector=nothing: A vector of means for each continuous feature (must be as long as num_continuous_feats). If nothing, then will be set randomly\nmin_sep::AbstractFloat=1.0: Minimum distance between any two randomly chosen means. Will have no effect if the means are given.\nstds::AbstractVector=nothing: A vector of standard deviations for each continuous feature (must be as long as num_continuous_feats). If nothing, then will be set randomly\nnum_vals_per_category::AbstractVector=[]: A vector of the number of levels of each extra categorical feature. the number of categorical features is inferred from this.\nclass_probs::AbstractVector{<:AbstractFloat}=[0.8, 0.2]: A vector of probabilities of each class. The number of classes is inferred from this vector.\ntype::AbstractString=\"ColTable\": Can be \"Matrix\" or \"ColTable\". In the latter case, a named-tuple of vectors is returned.\ninsert_y::Integer=nothing: If not nothing, insert the class labels column at the given index in the table\nrng::Union{AbstractRNG, Integer}=default_rng(): Random number generator. If integer then used as seed in Random.Xoshiro(seed) if the Julia VERSION supports it. Otherwise, uses Random.MersenneTwister(seed).\n\nReturns\n\nX:: A column table or matrix with generated imbalanced data with num_rows rows and num_continuous_feats + length(num_vals_per_category) columns. 
If insert_y is specified as an integer then y is also inserted at the specified index as an extra column.\ny::CategoricalArray: An abstract vector of class labels with labels 0, 1, 2, ..., k-1 where k=length(class_probs)\n\nExample\n\nusing Imbalance\nusing Plots\n\nnum_rows = 500\nnum_features = 2\n# generating continuous features given mean and std\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n\tmeans = [1.0, 4.0, [7.0 9.0]],\n\tstds = [1.0, [0.5 0.8], 2.0],\n\tclass_probs=[0.5, 0.2, 0.3],\n\ttype=\"Matrix\",\n\trng = 42,\n)\n\np = plot()\n[scatter!(p, X[:, 1][y.==yi], X[:, 2][y.==yi], label = \"y=$yi\") for yi in unique(y)]\n\njulia> plot(p)\n\n(Image: generated data)\n\n# generating continuous features with random mean and std\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n    min_sep=0.3, \n\tclass_probs=[0.5, 0.2, 0.3],\n\ttype=\"Matrix\",\n\trng = 33,\n)\n\np = plot()\n[scatter!(p, X[:, 1][y.==yi], X[:, 2][y.==yi], label = \"y=$yi\") for yi in unique(y)]\n\njulia> plot(p)\n\n(Image: generated data)\n\nnum_rows = 500\nnum_features = 2\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n    num_vals_per_category = [3, 5, 2],\n\tclass_probs=[0.9, 0.1],\n\tinsert_y=4,\n\ttype=\"ColTable\",\n\trng = 33,\n)\n\njulia> X\n(Column1 = [0.883, 0.9, 0.577 … 0.887,],\n Column2 = [0.578, 0.718, 0.378 … 0.573,],\n Column3 = [2.0, 2.0, 3.0, … 2.0,],\n Column4 = [0.0, 0.0, 0.0, … 0.0,],\n Column5 = [2.0, 3.0, 4.0, … 4.0,],\n Column6 = [1.0, 1.0, 2.0, … 1.0,],)\n\n\n\n\n\n","category":"function"},{"location":"algorithms/extra_algorithms/#Check-Balance-of-Data","page":"Extras","title":"Check Balance of Data","text":"","category":"section"},{"location":"algorithms/extra_algorithms/","page":"Extras","title":"Extras","text":"checkbalance","category":"page"},{"location":"algorithms/extra_algorithms/#Imbalance.checkbalance","page":"Extras","title":"Imbalance.checkbalance","text":"checkbalance(y; ref=\"majority\")\n\nA visual version of StatsBase.countmap that returns nothing and prints how many observations in the dataset belong to each class and their percentage relative to the size of the majority or minority class.\n\nArguments\n\ny::AbstractVector: A vector of categorical values to test for imbalance\nref=\"majority\": Either \"majority\" or \"minority\" and decides whether the percentage should be relative to the size of the majority or minority class.\n\nExample\n\nnum_rows = 50000\nnum_features = 2\nX, y = generate_imbalanced_data(\n\tnum_rows,\n\tnum_features;\n\tclass_probs=[0.8, 0.2],\n\ttype=\"Matrix\",\n\trng = 42,\n)\n\njulia> Imbalance.checkbalance(y; ref=\"majority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 10034 (25.1%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (100.0%) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 10034 (100.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 39966 (398.3%) \n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Oversampling-Algorithms","page":"Oversampling","title":"Oversampling Algorithms","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"The following table lists the supported oversampling algorithms, whether each mechanism repeats or generates data, and the supported types of data.","category":"page"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"Oversampling Method Mechanism Supported Data 
Types\nRandom Oversampler Repeat existing data Continuous and/or nominal\nRandom Walk Oversampler Generate synthetic data Continuous and/or nominal\nROSE Generate synthetic data Continuous\nSMOTE Generate synthetic data Continuous\nBorderline SMOTE1 Generate synthetic data Continuous\nSMOTE-N Generate synthetic data Nominal\nSMOTE-NC Generate synthetic data Continuous and nominal","category":"page"},{"location":"algorithms/oversampling_algorithms/#Random-Oversampler","page":"Oversampling","title":"Random Oversampler","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"random_oversample","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.random_oversample","page":"Oversampling","title":"Imbalance.random_oversample","text":"random_oversample(\n X, y; \n ratios=1.0, rng=default_rng(), \n try_preserve_type=true\n)\n\nDescription\n\nNaively oversample a dataset by randomly repeating existing observations with replacement.\n\nPositional Arguments\n\nX: A matrix of real numbers or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. 
depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply random oversampling\nXover, yover = random_oversample(X, y; ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomOversampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nRandomOversampler = @load RandomOversampler pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = RandomOversampler(ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate Random Oversampler model\noversampler = RandomOversampler(y_ind; ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Random-Walk-Oversampler","page":"Oversampling","title":"Random Walk Oversampler","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"random_walk_oversample","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.random_walk_oversample","page":"Oversampling","title":"Imbalance.random_walk_oversample","text":"random_walk_oversample(\n\tX, y, cat_inds;\n\tratios=1.0, rng=default_rng(),\n\ttry_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using random walk oversampling as presented in [1]. 
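Roughly, each synthetic point starts from a randomly chosen observation of the class and takes a random step along each continuous feature, with step size scaled by that feature's standard deviation within the class (nominal features are handled separately). A simplified sketch of the continuous step, as an illustration rather than the package's exact implementation:\n\nusing Random\n\n# one random-walk step away from observation x (a vector of continuous features);\n# σ holds the per-feature standard deviations of x's class and n is the class size\nwalk_step(x, σ, n; rng=Random.default_rng()) = x .+ (σ ./ sqrt(n)) .* randn(rng, length(x))\n\nwalk_step([1.0, 2.0], [0.5, 0.3], 30) # hypothetical values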
\n\nPositional Arguments\n\nX: A matrix of floats or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\ncat_inds::AbstractVector{<:Int}: A vector of the indices of the nominal features. Supplied only if X is a matrix. Otherwise, they are inferred from the table's scitypes.\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 3\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \n\t\t\t\t\t\t\t\t \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Continuous, Continuous, Continuous, Continuous, Continuous)\n# coerce nominal columns to a finite scitype (multiclass or ordered factor)\nX = coerce(X, :Column4=>Multiclass, :Column5=>Multiclass)\n\n# apply random walk oversampling\nXover, yover = random_walk_oversample(X, y; \n ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomWalkOversampling model and pass the \tpositional arguments (excluding cat_inds) to the transform method. 
\n\nusing MLJ\nRandomWalkOversampler = @load RandomWalkOversampler pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = RandomWalkOversampler(ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser. Note that only Table input is supported by the MLJ interface for this method.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind \tmust be specified to the constructor to specify which column y is followed by other keyword arguments. \tOnly Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 3\ny_ind = 2\n\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite or continuous scitypes \nXy = coerce(Xy, :Column2=>Multiclass, :Column5=>Multiclass, :Column6=>Multiclass)\n\n# Initiate Random Walk Oversampler model\noversampler = RandomWalkOversampler(y_ind;\n ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Zhang, H., & Li, M. (2014). RWO-Sampling: A random walk over-sampling approach to imbalanced data classification. Information Fusion, 25, 4-20.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#ROSE","page":"Oversampling","title":"ROSE","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"rose","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.rose","page":"Oversampling","title":"Imbalance.rose","text":"rose(\n X, y; \n s=1.0, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using ROSE (Random Oversampling Examples) algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\ns::float=1.0: A parameter that proportionally controls the bandwidth of the Gaussian kernel\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. 
Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply ROSE\nXover, yover = rose(X, y; s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ROSE model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nROSE = @load ROSE pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = ROSE(s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate ROSE model\noversampler = ROSE(y_ind; s=0.3, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] G Menardi, N. 
Torelli, “Training and assessing classification rules with imbalanced data,” Data Mining and Knowledge Discovery, 28(1), pp.92-122, 2014.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE","page":"Oversampling","title":"SMOTE","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smote","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smote","page":"Oversampling","title":"Imbalance.smote","text":"smote(\n X, y;\n k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using SMOTE (Synthetic Minority Oversampling Techniques) algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\n# apply SMOTE\nXover, yover = smote(X, y; k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the SMOTE model and pass the positional arguments X, y to the transform method. 
\n\nusing MLJ\nSMOTE = @load SMOTE pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTE(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate SMOTE model\noversampler = SMOTE(y_ind; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#Borderline-SMOTE1","page":"Oversampling","title":"Borderline SMOTE1","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"borderline_smote1","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.borderline_smote1","page":"Oversampling","title":"Imbalance.borderline_smote1","text":"borderline_smote1(\n X, y;\n m=5, k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true, verbosity=1\n)\n\nDescription\n\nOversamples a dataset using borderline SMOTE1 algorithm to correct for class imbalance as presented in [1]\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nm::Integer=5: The number of neighbors to consider while checking the BorderlineSMOTE1 condition. Should be within the range 0 < m < N where N is the number of observations in the data. It will be automatically set to N-1 if N ≤ m.\nk::Integer=5: Number of nearest neighbors to consider in the SMOTE part of the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class. It will be automatically set to l-1 for any class with l points where l ≤ k.\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. 
By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nverbosity::Integer=1: Whenever higher than 0 info regarding the points that will participate in oversampling is logged.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 1000, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n stds=[0.1 0.1 0.1], min_sep=0.01, class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 200 (40.8%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 310 (63.3%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 490 (100.0%) \n\n# apply BorderlineSMOTE1\nXover, yover = borderline_smote1(X, y; m = 3, \n k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 392 (80.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 441 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 490 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the BorderlineSMOTE1 model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nBorderlineSMOTE1 = @load BorderlineSMOTE1 pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = BorderlineSMOTE1(m=3, k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. 
Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 1000\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], min_sep=0.01, insert_y=y_ind, rng=42)\n\n# Initiate BorderlineSMOTE1 Oversampler model\noversampler = BorderlineSMOTE1(y_ind; m=3, k=5, \n ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) \n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Han, H., Wang, W.-Y., & Mao, B.-H. (2005). Borderline-SMOTE: A new over-sampling method in imbalanced data sets learning. In D.S. Huang, X.-P. Zhang, & G.-B. Huang (Eds.), Advances in Intelligent Computing (pp. 878-887). Springer. \n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE-N","page":"Oversampling","title":"SMOTE-N","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smoten","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smoten","page":"Oversampling","title":"Imbalance.smoten","text":"smoten(\n X, y;\n k=5, ratios=1.0, rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using SMOTE-N (Synthetic Minority Oversampling Techniques-Nominal) algorithm to correct for class imbalance as presented in [1]. This is a variant of SMOTE to deal with datasets where all features are nominal.\n\nPositional Arguments\n\nX: A matrix of integers or a table with element scitypes that subtype Finite. That is, for table inputs each column should have either OrderedFactor or Multiclass as the element scitype.\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. 
depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 0\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Count, Count)\n\n# coerce to a finite scitype (multiclass or ordered factor)\nX = coerce(X, autotype(X, :few_to_finite))\n\n# apply SMOTEN\nXover, yover = smoten(X, y; k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the SMOTEN model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nSMOTEN = @load SMOTEN pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTEN(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 0\ny_ind = 2\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite scitypes \nXy = coerce(Xy, :Column1=>Multiclass, :Column2=>Multiclass, :Column3=>Multiclass)\n\n# Initiate SMOTEN model\noversampler = SMOTEN(y_ind; k=5, ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nThe reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversample, Xy) and the revert(oversampler, Xy, cache) reverts the transform by removing the oversampled observations from the table.\n\nIllustration\n\nA full basic example can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. 
Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/oversampling_algorithms/#SMOTE-NC","page":"Oversampling","title":"SMOTE-NC","text":"","category":"section"},{"location":"algorithms/oversampling_algorithms/","page":"Oversampling","title":"Oversampling","text":"smotenc","category":"page"},{"location":"algorithms/oversampling_algorithms/#Imbalance.smotenc","page":"Oversampling","title":"Imbalance.smotenc","text":"smotenc(\n X, y, split_ind;\n k=5, ratios=1.0, knn_tree=\"Brute\", rng=default_rng(),\n try_preserve_type=true\n)\n\nDescription\n\nOversamples a dataset using SMOTE-NC (Synthetic Minority Oversampling Techniques-Nominal Continuous) algorithm to correct for class imbalance as presented in [1]. This is a variant of SMOTE to deal with datasets with both nominal and continuous features. \n\nwarning: SMOTE-NC Assumes Continuous Features Exist\nSMOTE-NC will not work if the dataset is purely nominal. In that case, refer to SMOTE-N instead. Meanwhile, if the dataset is purely continuous then it's equivalent to the standard SMOTE.\n\nPositional Arguments\n\nX: A matrix of floats or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\ncat_inds::AbstractVector{<:Int}: A vector of the indices of the nominal features. Supplied only if X is a matrix. Otherwise, they are inferred from the table's scitypes.\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the smallest class.\n\nratios=1.0: A parameter that controls the amount of oversampling to be done for each class\nCan be a float and in this case each class will be oversampled to the size of the majority class times the float. By default, all classes are oversampled to the size of the majority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nknn_tree: Decides the tree used in KNN computations. Either \"Brute\" or \"Ball\". BallTree can be much faster but may lead to inaccurate results.\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nXover: A matrix or table that includes original data and the new observations due to oversampling. 
depending on whether the input X is a matrix or table respectively\nyover: An abstract vector of labels corresponding to Xover\n\nExample\n\nusing Imbalance\nusing ScientificTypes\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows = 100\nnum_continuous_feats = 3\n# want two categorical features with three and two possible values respectively\nnum_vals_per_category = [3, 2]\n\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, num_vals_per_category, rng=42) \njulia> Imbalance.checkbalance(y)\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (39.6%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (68.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\njulia> ScientificTypes.schema(X).scitypes\n(Continuous, Continuous, Continuous, Continuous, Continuous)\n# coerce nominal columns to a finite scitype (multiclass or ordered factor)\nX = coerce(X, :Column4=>Multiclass, :Column5=>Multiclass)\n\n# apply SMOTE-NC\nXover, yover = smotenc(X, y; k = 5, ratios = Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng = 42)\n\njulia> Imbalance.checkbalance(yover)\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 38 (79.2%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 43 (89.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the SMOTENC model and pass the positional arguments (excluding cat_inds) to the transform method. \n\nusing MLJ\nSMOTENC = @load SMOTENC pkg=Imbalance\n\n# Wrap the model in a machine\noversampler = SMOTENC(k=5, ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nmach = machine(oversampler)\n\n# Provide the data to transform (there is nothing to fit)\nXover, yover = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser. Note that only Table input is supported by the MLJ interface for this method.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing ScientificTypes\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_continuous_feats = 3\ny_ind = 2\n# generate a table and categorical vector accordingly\nXy, _ = generate_imbalanced_data(num_rows, num_continuous_feats; insert_y=y_ind,\n class_probs= [0.5, 0.2, 0.3], num_vals_per_category=[3, 2],\n rng=42) \n\n# Table must have only finite or continuous scitypes \nXy = coerce(Xy, :Column2=>Multiclass, :Column5=>Multiclass, :Column6=>Multiclass)\n\n# Initiate SMOTENC model\noversampler = SMOTENC(y_ind; k=5, ratios=Dict(1=>1.0, 2=> 0.9, 3=>0.9), rng=42)\nXyover = Xy |> oversampler \n# equivalently if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy) # equivalently\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. 
Kegelmeyer, “SMOTE: synthetic minority over-sampling technique,” Journal of artificial intelligence research, 321-357, 2002.\n\n\n\n\n\n","category":"function"},{"location":"examples/walkthrough/#Introduction","page":"Introduction","title":"Introduction","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In this section of the docs, we will walk you through some examples to demonstrate how you can use Imbalance.jl in your machine learning project. Although we focus on examples, you can learn more about how specific algorithms work by reading this series of blogposts on Medium.","category":"page"},{"location":"examples/walkthrough/#Prerequisites","page":"Introduction","title":"Prerequisites","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In the examples that follow, we will assume familiarity with the CSV, DataFrames, ScientificTypes and MLJ packages, all of which come with excellent documentation. This example is devoted to establishing your familiarity with these packages. You can try all of the examples in the docs in your browser using Google Colab; you can read more about that in the last section.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \n \"Imbalance\", \"MLJBalancing\", \"ScientificTypes\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing ScientificTypes\nusing HTTP: download","category":"page"},{"location":"examples/walkthrough/#Loading-Data","page":"Introduction","title":"Loading Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"In this example, we will consider the BMI dataset found on Kaggle where the objective is to predict the BMI index of individuals given their gender, weight and height. 
","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/bmi.csv\", \"./\")\ndf = CSV.read(\"./bmi.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌─────────┬────────┬────────┬───────┐\n│ Gender │ Height │ Weight │ Index │\n│ String7 │ Int64 │ Int64 │ Int64 │\n│ Textual │ Count │ Count │ Count │\n├─────────┼────────┼────────┼───────┤\n│ Male │ 174 │ 96 │ 4 │\n│ Male │ 189 │ 87 │ 2 │\n│ Female │ 185 │ 110 │ 4 │\n│ Female │ 195 │ 104 │ 3 │\n│ Male │ 149 │ 61 │ 3 │\n└─────────┴────────┴────────┴───────┘\n\n\n┌ Warning: Reading one byte at a time from HTTP.Stream is inefficient.\n│ Use: io = BufferedInputStream(http::HTTP.Stream) instead.\n│ See: https://github.com/BioJulia/BufferedStreams.jl\n└ @ HTTP.Streams /Users/essam/.julia/packages/HTTP/SN7VW/src/Streams.jl:240\n┌ Info: Downloading\n│ source = https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/bmi.csv\n│ dest = ./bmi.csv\n│ progress = NaN\n│ time_taken = 0.0 s\n│ time_remaining = NaN s\n│ average_speed = 7.933 MiB/s\n│ downloaded = 8.123 KiB\n│ remaining = ∞ B\n│ total = ∞ B\n└ @ HTTP /Users/essam/.julia/packages/HTTP/SN7VW/src/download.jl:132","category":"page"},{"location":"examples/walkthrough/#Coercing-Data","page":"Introduction","title":"Coercing Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. Among the many types defined by the package, we are interested in Multiclass, OrderedFactor which fall under the Finite abstract type and Continuous and Count which fall under the Infinite abstract type.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"One motivation for this package is that it's not generally obvious whether numerical data in an input table is of continuous type or categorical type given that numbers can describe both. 
For instance, a column of integers from 1 to 5 could equally well be a count or a rating on an ordinal scale. Meanwhile, it's problematic if a model treats numerical data as, say, Continuous or Count when it's in reality nominal (i.e., Multiclass) or ordinal (i.e., OrderedFactor).","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We can use schema(df) to see how each feature is currently going to be interpreted by the resampling algorithms: ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌────────┬──────────┬─────────┐\n│ names │ scitypes │ types │\n├────────┼──────────┼─────────┤\n│ Gender │ Textual │ String7 │\n│ Height │ Count │ Int64 │\n│ Weight │ Count │ Int64 │\n│ Index │ Count │ Int64 │\n└────────┴──────────┴─────────┘","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"To change encodings that are leading to incorrect interpretations (true for all variables in this example), we use the coerce method, as follows:","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"df = coerce(df,\n :Gender => Multiclass,\n :Height => Continuous,\n :Weight => Continuous,\n :Index => OrderedFactor)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌────────┬──────────────────┬───────────────────────────────────┐\n│ names │ scitypes │ types │\n├────────┼──────────────────┼───────────────────────────────────┤\n│ Gender │ Multiclass{2} │ CategoricalValue{String7, UInt32} │\n│ Height │ Continuous │ Float64 │\n│ Weight │ Continuous │ Float64 │\n│ Index │ OrderedFactor{6} │ CategoricalValue{Int64, UInt32} │\n└────────┴──────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/walkthrough/#Unpacking-and-Splitting-Data","page":"Introduction","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish this using unpack from MLJ:","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y, X = unpack(df, ==(:Index); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌───────────────────────────────────┬────────────┬────────────┐\n│ Gender │ Height │ Weight │\n│ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │\n│ Multiclass{2} │ Continuous │ Continuous │\n├───────────────────────────────────┼────────────┼────────────┤\n│ Female │ 173.0 │ 82.0 │\n│ Female │ 187.0 │ 121.0 │\n│ Male │ 144.0 │ 145.0 │\n│ Male │ 156.0 │ 74.0 │\n│ Male │ 167.0 │ 151.0 │\n└───────────────────────────────────┴────────────┴────────────┘","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. 
stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits which is more representative of the real world.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(42)\n)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"((399×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 179.0 150.0\n 2 │ Male 141.0 80.0\n 3 │ Male 179.0 152.0\n 4 │ Male 187.0 138.0\n 5 │ Male 148.0 155.0\n 6 │ Female 192.0 101.0\n 7 │ Male 145.0 78.0\n 8 │ Female 162.0 159.0\n ⋮ │ ⋮ ⋮ ⋮\n 393 │ Female 161.0 154.0\n 394 │ Female 172.0 109.0\n 395 │ Female 163.0 159.0\n 396 │ Female 186.0 146.0\n 397 │ Male 194.0 106.0\n 398 │ Female 167.0 153.0\n 399 │ Female 162.0 64.0\n 384 rows omitted, 101×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 157.0 56.0\n 2 │ Male 180.0 75.0\n 3 │ Female 157.0 110.0\n 4 │ Female 182.0 143.0\n 5 │ Male 165.0 104.0\n 6 │ Male 182.0 73.0\n 7 │ Male 165.0 68.0\n 8 │ Male 166.0 107.0\n ⋮ │ ⋮ ⋮ ⋮\n 95 │ Male 163.0 137.0\n 96 │ Female 188.0 99.0\n 97 │ Female 146.0 123.0\n 98 │ Male 186.0 68.0\n 99 │ Female 140.0 76.0\n 100 │ Female 168.0 139.0\n 101 │ Male 180.0 149.0\n 86 rows omitted), (CategoricalArrays.CategoricalValue{Int64, UInt32}[5, 5, 5, 4, 5, 3, 4, 5, 5, 5 … 5, 4, 4, 5, 4, 5, 5, 3, 5, 2], CategoricalArrays.CategoricalValue{Int64, UInt32}[2, 2, 5, 5, 4, 2, 2, 4, 3, 3 … 2, 0, 0, 5, 3, 5, 2, 4, 5, 5]))","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations then train-test contamination has occurred; novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/walkthrough/#Oversampling","page":"Introduction","title":"Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Before deciding to oversample, let's see how adverse is the imbalance problem, if it exists. Ideally, you may as well check if the classification model is robust to this problem.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0: ▇▇▇ 13 (6.6%) \n1: ▇▇▇▇▇▇ 22 (11.1%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 68 (34.3%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 69 (34.8%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 130 (65.7%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 198 (100.0%)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Looks like we have a class imbalance problem. 
Let's set the desired ratios so that the first two classes are 30% of the majority class, the second two are 50% of the majority class, and the rest are left as they are (omitted from the dictionary)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"ratios = Dict(0=>0.3, 1=>0.3, 2=>0.5, 3=>0.5) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Dict{Int64, Float64} with 4 entries:\n 0 => 0.3\n 2 => 0.5\n 3 => 0.5\n 1 => 0.3","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Let's use random oversampling to oversample the data. This particular model does not care about the scientific types of the data. It takes X and y as positional arguments; ratios and rng are the main keyword arguments","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Xover, yover = random_oversample(X_train, y_train; ratios, rng=42) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"(514×3 DataFrame\n Row │ Gender Height Weight \n │ Cat… Float64 Float64 \n─────┼──────────────────────────\n 1 │ Female 179.0 150.0\n 2 │ Male 141.0 80.0\n 3 │ Male 179.0 152.0\n 4 │ Male 187.0 138.0\n 5 │ Male 148.0 155.0\n 6 │ Female 192.0 101.0\n 7 │ Male 145.0 78.0\n 8 │ Female 162.0 159.0\n ⋮ │ ⋮ ⋮ ⋮\n 508 │ Female 196.0 50.0\n 509 │ Male 193.0 54.0\n 510 │ Male 182.0 50.0\n 511 │ Male 190.0 50.0\n 512 │ Male 190.0 50.0\n 513 │ Male 198.0 50.0\n 514 │ Male 198.0 50.0\n 499 rows omitted, CategoricalArrays.CategoricalValue{Int64, UInt32}[5, 5, 5, 4, 5, 3, 4, 5, 5, 5 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"checkbalance(yover)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 47 (29.7%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 47 (29.7%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 79 (50.0%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 79 (50.0%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 104 (65.8%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 158 (100.0%)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"This indeed aligns with the desired ratios we set earlier.","category":"page"},{"location":"examples/walkthrough/#Training-the-Model","page":"Introduction","title":"Training the Model","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Because we have the scientific types set up, we can easily check what models will be able to train on our data. 
This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"5-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Let's go for a decision tree from BetaML","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML verbosity=0\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier(max_depth=5, rng=Random.Xoshiro(42))\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 5, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 5, …)\n args: \n 1:\tSource @027 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @092 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/#After-Oversampling","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 5, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 5, …)\n args: \n 1:\tSource @592 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @711 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/#Evaluating-the-Model","page":"Introduction","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"To evaluate the model, we will use the balanced accuracy metric, which accounts equally for all classes. For instance, if we have two classes and we correctly classify 100% of the examples in the first and 50% of the examples in the second, then the balanced accuracy is (100+50)/2 = 75%. This holds regardless of how big or small each class is.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"The predict_mode function will return a vector of predictions given X_test and the fitted machine. It differs from predict in that it does not return the probabilities the model assigns to each class; instead, it returns the classes with the maximum probabilities; i.e., the modes.","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling-2","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0.62","category":"page"},{"location":"examples/walkthrough/#After-Oversampling-2","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"y_pred_over = predict_mode(mach_over, X_test)\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"0.75","category":"page"},{"location":"examples/walkthrough/#Evaluating-the-Model-Revisited","page":"Introduction","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy, resulting in an improvement of 13 percentage points. A more precise evaluation would use cross validation to combine many different point estimates into a more precise one (their average). 
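","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"As a rough sketch of that arithmetic (the fold scores here are hypothetical; the evaluate! calls below compute and report all of this for us):","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"using Statistics\n\nscores = [0.59, 0.47, 0.62, 0.70, 0.65] # hypothetical per-fold balanced accuracies\nm = mean(scores) # the combined point estimate\nse = std(scores) / sqrt(length(scores)) # standard error of the mean\n(m - 1.96se, m + 1.96se) # approximate 95% confidence interval","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"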
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation would imply a smaller confidence interval at the same confidence level.","category":"page"},{"location":"examples/walkthrough/#Before-Oversampling-3","page":"Introduction","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.621 │ 0.0913 │ [0.593, 0.473, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Under the normality assumption, the 95% confidence interval is 62.1±9.13%, which is quite wide. Let's see how it looks after oversampling.","category":"page"},{"location":"examples/walkthrough/#After-Oversampling-3","page":"Introduction","title":"After Oversampling","text":"","category":"section"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"At first glance, this seems nontrivial, since resampling will have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"For this, we construct the resampling model via its MLJ interface and then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.RandomOversampler(ratios=ratios, rng=42)\nmodel = DecisionTreeClassifier(max_depth=5, rng=Random.Xoshiro(42))\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = DecisionTreeClassifier(max_depth = 5, …), …)\n args: \n 1:\tSource @099 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{2}}}}\n 2:\tSource @071 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"true","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.7 │ 0.0717 │ [0.7, 0.536, 0. ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/walkthrough/","page":"Introduction","title":"Introduction","text":"This results in an interval 70±7.2% which can be viewed as a reasonable improvement over 62.1±9.13%. The uncertainty in the intervals can be explained by the fact that the dataset is small with many classes.","category":"page"},{"location":"contributing/#Directory-Structure","page":"Contributing","title":"Directory Structure","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"The folder structure is as follows:","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":".\n├── Imbalance.jl # entry point to package\n├── generic_resample.jl # functions used in all resampling methods\n├── generic_encoder.jl # used in all resampling methods that deal with categorical data\n├── table_wrappers.jl # generalizes a function that operates on matrices to tables\n├── class_counts.jl # used to compute number of data points to add or remove\n├── common # has julia files for common docs, error strings and utils\n├── distance_metrics # has distance metrics used by some resampling methods\n├── oversampling_methods # all oversampling methods live here\n├── undersampling_methods # all undersampling methods live here\n└── extras.jl # extra functions like generating data or checking balance","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"The purpose of each file is further documented therein at the beginning of the file. 
The files are listed here in the recommended order for reading through the code. ","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Any resampling method implemented in the oversampling_methods or undersampling_methods folder takes the following structure:","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"├── resample_method # contains implementation and interfaces for a resampling method\n│ ├── interface_mlj.jl # implements MLJ interface for the method\n│ ├── interface_tables.jl # implements Tables.jl interface for the method\n│ └── resample_method.jl # implements the method itself (pure functional interface)","category":"page"},{"location":"contributing/#Contribution","page":"Contributing","title":"Contribution","text":"","category":"section"},{"location":"contributing/#Reporting-Problems-or-Seeking-Support","page":"Contributing","title":"Reporting Problems or Seeking Support","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Do not hesitate to post a Github issue with your question or problem.","category":"page"},{"location":"contributing/#Adding-New-Resampling-Methods","page":"Contributing","title":"Adding New Resampling Methods","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Make a new folder resample_method for the method in the oversampling_methods or undersampling_methods\nImplement in resample_method/resample_method.jl the method over matrices for one minority class (see the sketch after this list)\nUse generic_oversample.jl to generalize it to work on the whole data\nUse table_wrapper.jl to generalize the method to work on tables and possibly use generic_encoder.jl\nImplement the MLJ interface for the method in resample_method/interface_mlj\nImplement the TableTransforms interface for the method in resample_method/interface_tables.jl\nUse the rest of the files according to their description\nTesting and documentation should be done in parallel","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Of course, you can skip the third step if the algorithm you are implementing does not operate in a \"per-class\" sense.","category":"page"},
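{"location":"contributing/","page":"Contributing","title":"Contributing","text":"To make the second step concrete, here is a hypothetical skeleton for resample_method/resample_method.jl. The name, signature, and matrix orientation are illustrative assumptions, not the package's actual API:","category":"page"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"using Random: AbstractRNG, default_rng\n\n# Oversample a single class, given as a d×n matrix of its points, by\n# interpolating between randomly chosen pairs of existing points\n# (an illustrative sketch only)\nfunction interpolation_oversample(\n    X::AbstractMatrix{<:Real}, n_new::Integer;\n    rng::AbstractRNG = default_rng(),\n)\n    d, n = size(X)\n    Xnew = Matrix{float(eltype(X))}(undef, d, n_new)\n    for j in 1:n_new\n        a, b = rand(rng, 1:n), rand(rng, 1:n) # pick two existing points\n        r = rand(rng) # random interpolation factor in [0, 1)\n        Xnew[:, j] .= X[:, a] .+ r .* (X[:, b] .- X[:, a])\n    end\n    return Xnew\nend","category":"page"},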
{"location":"contributing/#Hot-algorithms-to-add","page":"Contributing","title":"🔥 Hot algorithms to add","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"K-Means SMOTE: Takes care of where exactly to generate more points using SMOTE by factoring in \"within class imbalance\". This may also be easily generalized to algorithms beyond SMOTE.\nCondensedNearestNeighbors: Undersamples the dataset so as to preserve the decision boundary learned by KNN\nBorderlineSMOTE2: A small modification of the BorderlineSMOTE1 condition\nRepeatedENNUndersampler: Simply repeats ENNUndersampler multiple times","category":"page"},{"location":"contributing/#Adding-New-Tutorials","page":"Contributing","title":"Adding New Tutorials","text":"","category":"section"},{"location":"contributing/","page":"Contributing","title":"Contributing","text":"Make a new notebook with the tutorial in the examples folder found in docs/src/examples\nRun the notebook so that the output is shown below each cell\nIf the notebook produces visuals then save and load them in the notebook\nConvert it to markdown by using Python to run from convert import convert_to_md; convert_to_md('')\nSet a title, description, image and links for it in the dictionary found in docs/examples.jl\nFor the colab link, you do not need to upload anything; just follow the link pattern in the file","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Effect-of-ENN-Hyperparameters","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \n \"ScientificTypes\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing ScientificTypes\nusing Plots, Measures\nusing HTTP: download","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Loading-Data","page":"Effect of ENN Hyperparameters","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"In this example, we will consider the BMI dataset found on Kaggle, where the objective is to predict the BMI index of individuals given their gender, weight and height. 
","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/effect_of_k_enn/bmi.csv\", \"./\")\n\ndf = CSV.read(\"./bmi.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌─────────┬────────┬────────┬───────┐\n│ Gender │ Height │ Weight │ Index │\n│ String7 │ Int64 │ Int64 │ Int64 │\n│ Textual │ Count │ Count │ Count │\n├─────────┼────────┼────────┼───────┤\n│ Male │ 174 │ 96 │ 4 │\n│ Male │ 189 │ 87 │ 2 │\n│ Female │ 185 │ 110 │ 4 │\n│ Female │ 195 │ 104 │ 3 │\n│ Male │ 149 │ 61 │ 3 │\n└─────────┴────────┴────────┴───────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will drop the gender attribute for purposes of visualization and to have more options for the model.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"select!(df, Not(:Gender)) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Coercing-Data","page":"Effect of ENN Hyperparameters","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────┬──────────┬───────┐\n│ names │ scitypes │ types │\n├────────┼──────────┼───────┤\n│ Height │ Count │ Int64 │\n│ Weight │ Count │ Int64 │\n│ Index │ Count │ Int64 │\n└────────┴──────────┴───────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Weight and Height should be Continuous and Index should be an OrderedFactor","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"df = coerce(df,\n :Height => Continuous,\n :Weight => Continuous,\n :Index => OrderedFactor)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────┬──────────────────┬─────────────────────────────────┐\n│ names │ scitypes │ types │\n├────────┼──────────────────┼─────────────────────────────────┤\n│ Height │ Continuous │ Float64 │\n│ Weight │ Continuous │ Float64 │\n│ Index │ OrderedFactor{6} │ CategoricalValue{Int64, UInt32} │\n└────────┴──────────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Unpacking-Data","page":"Effect of ENN Hyperparameters","title":"Unpacking 
Data","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"y, X = unpack(df, ==(:Index); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌────────────┬────────────┐\n│ Height │ Weight │\n│ Float64 │ Float64 │\n│ Continuous │ Continuous │\n├────────────┼────────────┤\n│ 173.0 │ 82.0 │\n│ 187.0 │ 121.0 │\n│ 144.0 │ 145.0 │\n│ 156.0 │ 74.0 │\n│ 167.0 │ 151.0 │\n└────────────┴────────────┘","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will skip splitting the data since the main purpose of this tutorial is visualization.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Undersampling","page":"Effect of ENN Hyperparameters","title":"Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Before undersampling, let's check the balance of the data","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"checkbalance(y; ref=\"minority\")","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"0: ▇▇▇ 13 (100.0%) \n1: ▇▇▇▇▇▇ 22 (169.2%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 68 (523.1%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 69 (530.8%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 130 (1000.0%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 198 (1523.1%)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Let's use ENN undersampling to undersample the data. ENN undersamples the data by \"cleaning it out\" or in another words deleting any point that violates a certain condition. We can limit the number of points that are deleted by setting the min_ratios parameter. ","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"We will set k=1 and keep_condition=\"only mode\" which means that any point with a label that is not the only most common one amongst its 1-nearest neighbors will be deleted (i.e., must have same label as its nearest neighbor). By setting min_ratios=1.0 we constraint that points should never be deleted form any class if it's ratio relative to the minority class will be less than 1.0. 
{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"X_under, y_under = enn_undersample(\n\tX,\n\ty;\n\tk = 1,\n\tkeep_condition = \"only mode\",\n\tmin_ratios=0.01,\n\trng = 42,\n)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(448×2 DataFrame\n Row │ Height Weight \n │ Float64 Float64 \n─────┼──────────────────\n 1 │ 173.0 82.0\n 2 │ 182.0 70.0\n 3 │ 156.0 52.0\n 4 │ 172.0 67.0\n 5 │ 162.0 58.0\n 6 │ 180.0 75.0\n 7 │ 190.0 83.0\n 8 │ 195.0 81.0\n ⋮ │ ⋮ ⋮\n 442 │ 196.0 50.0\n 443 │ 191.0 54.0\n 444 │ 185.0 52.0\n 445 │ 182.0 50.0\n 446 │ 198.0 50.0\n 447 │ 198.0 50.0\n 448 │ 181.0 51.0\n 433 rows omitted, CategoricalArrays.CategoricalValue{Int64, UInt32}[2, 2, 2, 2, 2, 2, 2, 2, 2, 2 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"checkbalance(y_under; ref=\"minority\")","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"0: ▇▇▇ 11 (100.0%) \n1: ▇▇▇▇▇ 19 (172.7%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 56 (509.1%) \n3: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 58 (527.3%) \n4: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 115 (1045.5%) \n5: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 189 (1718.2%)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Points were indeed deleted from every class, consistent with the permissive min_ratios we set earlier.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Training-the-Model","page":"Effect of ENN Hyperparameters","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Because we have the scientific types set up, we can easily check what models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"models(matching(X_under, y_under))","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... 
)\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Let's go for an SVM from LIBSVM","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import Pkg; Pkg.add(\"LIBSVM\")\nimport LIBSVM;","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":" Updating registry at `~/.julia/registries/General`\n Updating git-repo `https://github.com/JuliaRegistries/General.git`\n Resolving package versions...\n No Changes to `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n No Changes to `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Before-Undersampling","page":"Effect of ENN Hyperparameters","title":"Before Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"# 1. Load the model\nSVC = @load SVC pkg=LIBSVM\n\n# 2. Instantiate it\nmodel = SVC(kernel=LIBSVM.Kernel.RadialBasis, gamma=0.01) ## instance\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"import MLJLIBSVMInterface ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @987 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @104 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#After-Undersampling","page":"Effect of ENN Hyperparameters","title":"After Undersampling","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"# 3. Wrap it with the data in a machine\nmach_under = machine(model, X_under, y_under)\n\n# 4. fit the machine learning model\nfit!(mach_under)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @123 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @423 ⏎ AbstractVector{OrderedFactor{6}}","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Plot-Decision-Boundaries","page":"Effect of ENN Hyperparameters","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Construct ranges for each feature and consecutively a grid","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"height_range =\n\trange(minimum(X.Height) - 1, maximum(X.Height) + 1, length = 400)\nweight_range =\nrange(minimum(X.Weight) - 1, maximum(X.Weight) + 1, length = 400)\ngrid_points = [(h, w) for h in height_range, w in weight_range]","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"400×400 Matrix{Tuple{Float64, Float64}}:\n (139.0, 49.0) (139.0, 49.2807) (139.0, 49.5614) … (139.0, 161.0)\n (139.153, 49.0) (139.153, 49.2807) (139.153, 49.5614) (139.153, 161.0)\n (139.306, 49.0) (139.306, 49.2807) (139.306, 49.5614) (139.306, 161.0)\n (139.459, 49.0) (139.459, 49.2807) (139.459, 49.5614) (139.459, 161.0)\n (139.612, 49.0) (139.612, 49.2807) (139.612, 49.5614) (139.612, 161.0)\n (139.764, 49.0) (139.764, 49.2807) (139.764, 49.5614) … (139.764, 161.0)\n (139.917, 49.0) (139.917, 49.2807) (139.917, 49.5614) (139.917, 161.0)\n (140.07, 49.0) (140.07, 49.2807) (140.07, 49.5614) (140.07, 161.0)\n (140.223, 49.0) (140.223, 49.2807) (140.223, 49.5614) (140.223, 161.0)\n (140.376, 49.0) (140.376, 49.2807) (140.376, 49.5614) (140.376, 161.0)\n ⋮ ⋱ \n (198.777, 49.0) (198.777, 49.2807) (198.777, 49.5614) (198.777, 161.0)\n (198.93, 49.0) (198.93, 49.2807) (198.93, 49.5614) (198.93, 161.0)\n (199.083, 49.0) (199.083, 49.2807) (199.083, 49.5614) (199.083, 161.0)\n (199.236, 49.0) (199.236, 49.2807) (199.236, 49.5614) (199.236, 161.0)\n (199.388, 49.0) (199.388, 49.2807) (199.388, 49.5614) … (199.388, 161.0)\n (199.541, 
49.0) (199.541, 49.2807) (199.541, 49.5614) (199.541, 161.0)\n (199.694, 49.0) (199.694, 49.2807) (199.694, 49.5614) (199.694, 161.0)\n (199.847, 49.0) (199.847, 49.2807) (199.847, 49.5614) (199.847, 161.0)\n (200.0, 49.0) (200.0, 49.2807) (200.0, 49.5614) (200.0, 161.0)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Evaluate the grid with the machine before and after undersampling","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"grid_predictions =[\n predict(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n \tpoint in grid_points\n ]\n \ngrid_predictions_under = [\n predict(mach_under, Tables.table(reshape(collect(point), 1, 2)))[1] for\n point in grid_points\n]","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Make two contour plots using the grid predictions before and after undersampling","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"colors = [:green, :aqua, :violet, :red, :blue, :yellow]\np = contourf(weight_range, height_range, grid_predictions,\nlevels = 6, color = colors, colorbar = false)\np_under = contourf(weight_range, height_range, grid_predictions_under,\nlevels = 6, color = colors, colorbar = false)\nprintln()","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"labels = unique(y)\ncolors = Dict(\n\t0 => \"green\",\n\t1 => \"cyan3\",\n\t2 => \"violet\",\n\t3 => \"red\",\n\t4 => \"dodgerblue\",\n\t5 => \"gold2\",\n)\n\nfor label in labels\n\tscatter!(p, X.Weight[y.==label], X.Height[y.==label],\n\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\ttitle = \"Before Undersampling\")\n\tscatter!(p_under, X_under.Weight[y_under.==label], X_under.Height[y_under.==label],\n\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\ttitle = \"After Undersampling\")\nend\n\nplot_res = plot(\n\tp,\n\tp_under,\n\tlayout = (1, 2),\n\txlabel = \"Weight\",\n\tylabel = \"Height\",\n\tsize = (1200, 450),\n\tmargin = 5mm, dpi = 200,\n\tlegend = :outerbottomright,\n)\nsavefig(plot_res, \"./assets/ENN-before-after.png\")\n","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(Image: enn comparison)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/#Effect-of-k-Hyperparameter","page":"Effect of ENN Hyperparameters","title":"Effect of k Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"Now let's study the cleaning effect as k increases for all types of keep conditions.","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"anim = @animate for k ∈ 1:15\n\tconditions = [\"exists\", \"mode\", \"only mode\", \"all\"]\n\tplots = [plot() for _ in 1:4]\n\tdata_list = []\n\n\tfor i in 1:4\n\n\t\tX_under, y_under = 
enn_undersample(\n\t\t\tX,\n\t\t\ty;\n\t\t\tk = k,\n\t\t\tkeep_condition = conditions[i],\n\t\t\tmin_ratios = 0.01,\n\t\t\trng = 42,\n\t\t)\n\n\t\t# fit machine\n\t\tmach_under = machine(model, X_under, y_under)\n\t\tfit!(mach_under, verbosity = 0)\n\n\t\t# grid predictions\n\t\tgrid_predictions_under = [\n\t\t\tpredict(mach_under, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\t\tpoint in grid_points\n\t\t]\n\n\t\t# plot\n\t\tcolors = [:green, :aqua, :violet, :red, :blue, :yellow]\n\t\tcontourf!(plots[i], weight_range, height_range, grid_predictions_under,\n\t\t\tlevels = 6, color = colors, colorbar = false)\n\n\t\tcolors = Dict(\n\t\t\t0 => \"green\",\n\t\t\t1 => \"cyan3\",\n\t\t\t2 => \"violet\",\n\t\t\t3 => \"red\",\n\t\t\t4 => \"dodgerblue\",\n\t\t\t5 => \"gold2\",\n\t\t)\n\t\tfor label in labels\n\t\t\tscatter!(plots[i], X_under.Weight[y_under.==label],\n\t\t\t\tX_under.Height[y_under.==label],\n\t\t\t\tcolor = colors[label], label = label, markerstrokewidth = 1.5,\n\t\t\t\ttitle = \"$(conditions[i])\", legend = ((i == 2) ? :bottomright : :none))\n\t\tend\n\t\tplot!(\n\t\t\tplots[1], plots[2], plots[3], plots[4],\n\t\t\tlayout = (1, 4),\n\t\t\tsize = (1300, 420),\n\t\t\tplot_title = \"Undersampling with k =$k\",\n\t\t)\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"gif(anim, \"./assets/enn-k-animation.gif\", fps=1)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"(Image: enn-gif-hyperparameter)","category":"page"},{"location":"examples/effect_of_k_enn/effect_of_k_enn/","page":"Effect of ENN Hyperparameters","title":"Effect of ENN Hyperparameters","text":"As we can see, the most constraining condition is all. 
It deletes any point whose label differs from that of any of its k nearest neighbors, which also explains why it is the most sensitive to the hyperparameter k.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#SMOTEN-on-Mushroom-Data","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing StatsBase\nusing ScientificTypes\nusing Plots\nusing HTTP: download","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Loading-Data","page":"SMOTEN on Mushroom Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"In this example, we will consider the Mushroom dataset found on Kaggle, with the objective of predicting mushroom odour given various features of the mushroom.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smoten_mushroom/mushrooms.csv\", \"./\")\ndf = CSV.read(\"./mushrooms.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌─────────┬───────────┬─────────────┬───────────┬─────────┬─────────┬─────────────────┬──────────────┬───────────┬────────────┬─────────────┬────────────┬──────────────────────────┬──────────────────────────┬────────────────────────┬────────────────────────┬───────────┬────────────┬─────────────┬───────────┬───────────────────┬────────────┬─────────┐\n│ class │ cap-shape │ cap-surface │ cap-color │ bruises │ odor │ gill-attachment │ gill-spacing │ gill-size │ gill-color │ stalk-shape │ stalk-root │ stalk-surface-above-ring │ stalk-surface-below-ring │ stalk-color-above-ring │ stalk-color-below-ring │ veil-type │ veil-color │ ring-number │ ring-type │ spore-print-color │ population │ habitat │\n│ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │ String1 │\n│ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual │ Textual 
│\n├─────────┼───────────┼─────────────┼───────────┼─────────┼─────────┼─────────────────┼──────────────┼───────────┼────────────┼─────────────┼────────────┼──────────────────────────┼──────────────────────────┼────────────────────────┼────────────────────────┼───────────┼────────────┼─────────────┼───────────┼───────────────────┼────────────┼─────────┤\n│ p │ x │ s │ n │ t │ p │ f │ c │ n │ k │ e │ e │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ u │\n│ e │ x │ s │ y │ t │ a │ f │ c │ b │ k │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ n │ n │ g │\n│ e │ b │ s │ w │ t │ l │ f │ c │ b │ n │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ n │ n │ m │\n│ p │ x │ y │ w │ t │ p │ f │ c │ n │ n │ e │ e │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ u │\n│ e │ x │ s │ g │ f │ n │ f │ w │ b │ k │ t │ e │ s │ s │ w │ w │ p │ w │ o │ e │ n │ a │ g │\n└─────────┴───────────┴─────────────┴───────────┴─────────┴─────────┴─────────────────┴──────────────┴───────────┴────────────┴─────────────┴────────────┴──────────────────────────┴──────────────────────────┴────────────────────────┴────────────────────────┴───────────┴────────────┴─────────────┴───────────┴───────────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Visualize-the-Data","page":"SMOTEN on Mushroom Data","title":"Visualize the Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Since this dataset is composed only of categorical features, a bar chart for each column is a good way to visualize the data.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# Create a bar chart for each column\nbar_charts = []\nfor col in names(df)\n counts = countmap(df[!, col])\n k, v = collect(keys(counts)), collect(values(counts))\n if length(k) < 20\n push!(bar_charts, bar(k, v, legend=false, title=col))\n end\nend\n\n# Combine bar charts into a grid layout with specified plot size\nplot_res = plot(bar_charts..., layout=(5, 5), \n size=(1300, 1200), \n plot_title=\"Value Frequencies for each Categorical Variable\")\nsavefig(plot_res, \"./assets/mushroom-bar-charts.png\")","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"(Image: Mushroom Features Plots)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We will take the mushroom odour as our target and all the rest as features. ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Coercing-Data","page":"SMOTEN on Mushroom Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. 
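","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"For instance, a quick sketch of inspecting a single column (assuming ScientificTypes is loaded, as above):","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"scitype(df.class) # AbstractVector{Textual} before coercion","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"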
It's often necessary to coerce the types inferred by default to the appropriate type.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌──────────────────────────┬──────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────────────────┼──────────┼─────────┤\n│ class │ Textual │ String1 │\n│ cap-shape │ Textual │ String1 │\n│ cap-surface │ Textual │ String1 │\n│ cap-color │ Textual │ String1 │\n│ bruises │ Textual │ String1 │\n│ odor │ Textual │ String1 │\n│ gill-attachment │ Textual │ String1 │\n│ gill-spacing │ Textual │ String1 │\n│ gill-size │ Textual │ String1 │\n│ gill-color │ Textual │ String1 │\n│ stalk-shape │ Textual │ String1 │\n│ stalk-root │ Textual │ String1 │\n│ stalk-surface-above-ring │ Textual │ String1 │\n│ stalk-surface-below-ring │ Textual │ String1 │\n│ stalk-color-above-ring │ Textual │ String1 │\n│ stalk-color-below-ring │ Textual │ String1 │\n│ ⋮ │ ⋮ │ ⋮ │\n└──────────────────────────┴──────────┴─────────┘\n 7 rows omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"For instance, here we need to coerce all the data to Multiclass as they are all nominal variables. Textual would be the right type for natural language processing models. Instead of typing in each column manually, autotype lets us perform mass conversion using pre-defined rules.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"df = coerce(df, autotype(df, :few_to_finite))\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌──────────────────────────┬────────────────┬───────────────────────────────────\n│ names │ scitypes │ types ⋯\n├──────────────────────────┼────────────────┼───────────────────────────────────\n│ class │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-shape │ Multiclass{6} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-surface │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ cap-color │ Multiclass{10} │ CategoricalValue{String1, UInt32 ⋯\n│ bruises │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ odor │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-attachment │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-spacing │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-size │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ gill-color │ Multiclass{12} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-shape │ Multiclass{2} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-root │ Multiclass{5} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-surface-above-ring │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-surface-below-ring │ Multiclass{4} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-color-above-ring │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ stalk-color-below-ring │ Multiclass{9} │ CategoricalValue{String1, UInt32 ⋯\n│ ⋮ │ ⋮ │ ⋮ ⋱\n└──────────────────────────┴────────────────┴───────────────────────────────────\n 1 column and 7 rows 
omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Unpacking-and-Splitting-Data","page":"SMOTEN on Mushroom Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y, X = unpack(df, ==(:odor); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"┌───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┐\n│ class │ cap-shape │ cap-surface │ cap-color │ bruises │ gill-attachment │ gill-spacing │ gill-size │ gill-color │ stalk-shape │ stalk-root │ stalk-surface-above-ring │ stalk-surface-below-ring │ stalk-color-above-ring │ stalk-color-below-ring │ veil-type │ veil-color │ ring-number │ ring-type │ spore-print-color │ population │ habitat │\n│ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │ CategoricalValue{String1, UInt32} │\n│ Multiclass{2} │ Multiclass{6} │ Multiclass{4} │ Multiclass{10} │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{12} │ Multiclass{2} │ Multiclass{5} │ Multiclass{4} │ Multiclass{4} │ Multiclass{9} │ Multiclass{9} │ Multiclass{1} │ Multiclass{4} │ Multiclass{3} │ Multiclass{5} │ Multiclass{9} │ Multiclass{6} │ Multiclass{7} 
│\n├───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┤\n│ e │ f │ f │ n │ t │ f │ c │ b │ w │ t │ b │ s │ s │ g │ g │ p │ w │ o │ p │ k │ v │ d │\n│ e │ f │ f │ n │ t │ f │ c │ b │ w │ t │ b │ s │ s │ w │ p │ p │ w │ o │ p │ n │ y │ d │\n│ e │ b │ s │ y │ t │ f │ c │ b │ k │ e │ c │ s │ s │ w │ w │ p │ w │ o │ p │ k │ s │ g │\n│ p │ f │ y │ e │ f │ f │ c │ b │ w │ e │ c │ k │ y │ c │ c │ p │ w │ n │ n │ w │ c │ d │\n│ e │ x │ y │ n │ f │ f │ w │ n │ w │ e │ b │ f │ f │ w │ n │ p │ w │ o │ e │ w │ v │ l │\n└───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. 
stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits which is more representative of the real world.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, stratify=y, rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"(CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\"), String1(\"f\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\"), String1(\"s\")], CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"f\"), String1(\"y\"), String1(\"a\"), String1(\"c\"), String1(\"f\"), String1(\"n\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"f\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"f\"), String1(\"y\"), String1(\"f\"), String1(\"n\"), String1(\"n\"), String1(\"n\")])","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations then train-test contamination has occurred; novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Oversampling","page":"SMOTEN on Mushroom Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"It was obvious from the bar charts that there is a severe imbalance problem. Let's look at that again.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"m: ▇ 36 (1.0%) \nc: ▇▇▇ 192 (5.4%) \np: ▇▇▇▇ 256 (7.3%) \na: ▇▇▇▇▇▇ 400 (11.3%) \nl: ▇▇▇▇▇▇ 400 (11.3%) \ny: ▇▇▇▇▇▇▇▇ 576 (16.3%) \ns: ▇▇▇▇▇▇▇▇ 576 (16.3%) \nf: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2160 (61.2%) \nn: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 3528 (100.0%)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Let's set our desired ratios as follows. 
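For clarity, each ratio is interpreted relative to the majority class: when oversampling, a ratio r for class c requests roughly r times as many observations of c as there are majority-class observations. A minimal sketch of the arithmetic (the 2822 figure is the majority-class training count visible in the checkbalance output further below; target_count is our own helper, not part of Imbalance):

majority_count = 2822                        # training observations of majority class "n"
target_count(ratio) = round(Int, ratio * majority_count)
target_count(0.3)                            # 847, the count class "m" reaches below
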
These are set relative to the size of the majority class, as the sketch above illustrates.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ratios = Dict(\"m\"=>0.3, \n \"c\"=>0.4,\n \"p\"=>0.5,\n \"a\"=>0.5,\n \"l\"=>0.5,\n \"y\"=>0.7,\n \"s\"=>0.7,\n \"f\"=>0.8\n )","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Dict{String, Float64} with 8 entries:\n \"s\" => 0.7\n \"f\" => 0.8\n \"c\" => 0.4\n \"m\" => 0.3\n \"l\" => 0.5\n \"a\" => 0.5\n \"p\" => 0.5\n \"y\" => 0.7","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We set these by gut feeling here, but in practice the ratios are among the most important hyperparameters to tune. ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"The easy option ratios=1.0 always exists and means oversampling every class until it matches the size of the majority class. This may or may not be optimal, since fully balancing the data can promote overfitting.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Xover, yover = smoten(X_train, y_train; k=2, ratios=ratios, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Progress: 22%|█████████▏ | ETA: 0:00:01\u001b[K\n\u001b[A\nProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[K\n\u001b[A\n\n\n(15239×22 DataFrame\n Row │ class cap-shape cap-surface cap-color bruises gill-attachment g ⋯\n │ Cat… Cat… Cat… Cat… Cat… Cat… C ⋯\n───────┼────────────────────────────────────────────────────────────────────────\n 1 │ p f s e f f c ⋯\n 2 │ p f y e f f c\n 3 │ e f f w f f w\n 4 │ p f s e f f c\n 5 │ p f y e f f c ⋯\n 6 │ e s f g f f c\n 7 │ p f s n f f c\n 8 │ e x y g t f c\n ⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱\n 15233 │ p x y c f a c ⋯\n 15234 │ p x y e f a c\n 15235 │ p x y n f a c\n 15236 │ p k y c f f c\n 15237 │ p x y c f a c ⋯\n 15238 │ p k y c f f c\n 15239 │ p x y e f f c\n 16 columns and 15224 rows omitted, CategoricalArrays.CategoricalValue{String1, UInt32}[String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"s\"), String1(\"n\"), String1(\"s\"), String1(\"n\"), String1(\"n\"), String1(\"n\") … String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\"), String1(\"m\")])","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"SMOTEN uses a specialized distance metric to find the nearest neighbors, which explains why it may be a bit slow: it is nontrivial to optimize KNN search over such a metric.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Now let's check the balance of the data","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom 
Data","text":"checkbalance(yover)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"m: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 847 (30.0%) \nc: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1129 (40.0%) \na: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \nl: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \np: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1411 (50.0%) \ny: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1975 (70.0%) \ns: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 1975 (70.0%) \nf: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2258 (80.0%) \nn: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 2822 (100.0%)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Training-the-Model","page":"SMOTEN on Mushroom Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Because we have scientific types setup, we can easily check what models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"ms = models(matching(Xover, yover))","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"6-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = OneRuleClassifier, package_name = OneRule, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Let's go for a OneRuleClassifier","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import Pkg; Pkg.add(\"OneRule\")","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":" Resolving package versions...\n Installed MLJBalancing ─ v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n [45f359ea] + MLJBalancing v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`\n [45f359ea] + MLJBalancing v0.1.0\nPrecompiling project...\n ✓ MLJBalancing\n 1 dependency successfully precompiled in 25 seconds. 
262 already precompiled.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 1. Load the model\nOneRuleClassifier= @load OneRuleClassifier pkg=OneRule\n\n# 2. Instantiate it\nmodel = OneRuleClassifier()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"import OneRule ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. \n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\ntrained Machine; caches model-specific representations of data\n model: OneRuleClassifier()\n args: \n 1:\tSource @978 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @097 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"trained Machine; caches model-specific representations of data\n model: OneRuleClassifier()\n args: \n 1:\tSource @469 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @942 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Evaluating-the-Model","page":"SMOTEN on Mushroom Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"To evaluate the model, we will use the balanced accuracy metric, which accounts for all classes equally.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling-2","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y_pred = MLJ.predict(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"0.22","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling-2","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"y_pred_over = MLJ.predict(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"0.4","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Evaluating-the-Model-Revisited","page":"SMOTEN on Mushroom Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We previously evaluated the model using a single point estimate of the balanced accuracy, which suggested a full-blown 18% improvement. A more precise evaluation would use cross-validation to combine many different point estimates into a more precise one (their average). The standard deviation among such point estimates also lets us quantify the uncertainty of the estimate; a smaller standard deviation implies a tighter confidence interval at the same confidence level. 
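Concretely, under a normal approximation, n point estimates with mean m and standard deviation σ yield the 95% interval m ± 1.96·σ/√n. A one-line sketch of the half-width (our own helper; it mirrors, approximately, the 1.96*SE column that evaluate! reports below):

ci_halfwidth(σ, n) = 1.96 * σ / sqrt(n)   # 95% interval: m ± ci_halfwidth(σ, n)
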
","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#Before-Oversampling-3","page":"SMOTEN on Mushroom Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬───────────┬─────────────┬──────────┬────────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼───────────┼─────────────┼──────────┼────────────────────\n│ BalancedAccuracy( │ predict │ 0.218 │ 0.000718 │ [0.218, 0.218, 0. ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴───────────┴─────────────┴──────────┴────────────────────\n 1 column omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Before oversampling, and assuming that the balanced accuracy score is normally distributed, we can be 95% confident that the balanced accuracy on new data is 21.8±0.07%. This is a more precise estimate than the single 22% figure we had earlier.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/#After-Oversampling-3","page":"SMOTEN on Mushroom Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"At first glance, this seems nontrivial, since resampling would have to be performed on each training fold before fitting the model during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"To do this, we construct the resampler via its MLJ interface, then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTEN(k=2, ratios=ratios, rng=Random.Xoshiro(42))\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Progress: 22%|█████████▏ | ETA: 0:00:01\u001b[K\n\u001b[A\nProgress: 56%|██████████████████████▊ | ETA: 0:00:00\u001b[K\nProgress: 22%|█████████▏ | ETA: 0:00:00\u001b[K\n\u001b[A\nProgress: 78%|███████████████████████████████▉ | ETA: 0:00:00\u001b[K\n\u001b[A\n\n\ntrained Machine; does not cache data\n model: BalancedModelDeterministic(balancers = Imbalance.MLJ.SMOTEN{Dict{String, Float64}, Xoshiro}[SMOTEN(k = 2, …)], …)\n args: \n 1:\tSource @692 ⏎ Table{Union{AbstractVector{Multiclass{10}}, AbstractVector{Multiclass{12}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{1}}, AbstractVector{Multiclass{4}}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{9}}, AbstractVector{Multiclass{6}}, AbstractVector{Multiclass{7}}}}\n 2:\tSource @468 ⏎ AbstractVector{Multiclass{9}}","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":" y_pred_over == predict(mach_over, X_test)","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"true","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"cv=CV(nfolds=10)\ne = evaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"e","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬───────────┬─────────────┬─────────┬─────────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼───────────┼─────────────┼─────────┼─────────────────────\n│ BalancedAccuracy( │ predict │ 0.4 │ 0.00483 │ [0.398, 0.405, 0.3 ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴───────────┴─────────────┴─────────┴─────────────────────\n 1 column omitted","category":"page"},{"location":"examples/smoten_mushroom/smoten_mushroom/","page":"SMOTEN on Mushroom Data","title":"SMOTEN on Mushroom Data","text":"Fair enough. 
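For reference, these interval components can be read straight off the returned PerformanceEvaluation object; a minimal sketch under the same normal approximation (the field names are those printed in the output above):

using Statistics
bacc_per_fold = e.per_fold[1]       # the 10 per-fold balanced accuracies
point_estimate = e.measurement[1]   # ≈ 0.4, their aggregate
half_width = 1.96 * std(bacc_per_fold) / sqrt(length(bacc_per_fold))   # ≈ 0.005, the 1.96*SE column
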
After oversampling, the interval under the same assumptions is 40±0.5%; this agrees with our earlier observations using simple point estimates: oversampling here delivers approximately an 18% improvement in balanced accuracy.","category":"page"},{"location":"algorithms/implementation_notes/#Generalizing-to-Multiclass","page":"Implementation Notes","title":"Generalizing to Multiclass","text":"","category":"section"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"Papers often propose the resampling algorithm for the case of binary classification only. In many cases, the algorithm only expects a set of points to resample and has nothing to do with the existence of a majority class (e.g., estimates the distribution of points then generates new samples from it) so it can be generalized by simply applying the algorithm on each class. In other cases, there is an interaction with the majority class (e.g., a point is borderline in BorderlineSMOTE1 if the majority but not all its neighbors are from the majority class). In this case, a one-vs-rest scheme is used as proposed in [1]. For instance, a point is now borderline if most but not all its neighbors are from a different class. ","category":"page"},{"location":"algorithms/implementation_notes/#Generalizing-to-Real-Ratios","page":"Implementation Notes","title":"Generalizing to Real Ratios","text":"","category":"section"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"Papers often propose the resampling algorithm using integer ratios. For instance, a ratio of 2 would mean doubling the amount of data in a class, and a ratio of 2.2 is either not allowed or gets rounded. In Imbalance.jl any appropriate real ratio can be used, and the ratio is relative to the size of the majority or minority class depending on whether the algorithm is oversampling or undersampling. The generalization occurs by randomly choosing points instead of looping on each point. That is, if a 2.2 ratio corresponds to 227 examples, then 227 seed points are chosen randomly with replacement and the resampling logic is applied to each, as sketched below. Given an integer ratio k, this is on average equivalent to looping over the points k times.","category":"page"},{"location":"algorithms/implementation_notes/","page":"Implementation Notes","title":"Implementation Notes","text":"[1] Fernández, A., López, V., Galar, M., Del Jesus, M. J., and Herrera, F. (2013). Analysing the classification of imbalanced data-sets with multiple classes: Binarization techniques and ad-hoc approaches. Knowledge-Based Systems, 42:97–110.","category":"page"},
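A minimal sketch of this real-ratio generalization (our own illustration, not the package source): pick the required number of seed points uniformly with replacement, then apply the per-point resampling logic to each.

using Random
rng = Xoshiro(42)
class_points = collect(1:103)              # indices of a class's existing points (hypothetical count)
n_needed = 227                             # number of examples implied by the real ratio, as in the text
seeds = rand(rng, class_points, n_needed)  # seed points sampled with replacement
# the per-point resampling logic (e.g., interpolation) is then applied to each seed
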
{"location":"algorithms/mlj_balancing/#Combining-Resamplers","page":"Combination","title":"Combining Resamplers","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Resampling methods can be combined sequentially or in parallel, along with a classification model, to yield hybrid or ensemble models that may be even more powerful than using the classification model with only one of the individual resamplers.","category":"page"},{"location":"algorithms/mlj_balancing/#Sequential-Resampling","page":"Combination","title":"Sequential Resampling","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"MLJBalancing.jl allows chaining an arbitrary number of resamplers from Imbalance.jl (also called balancers) with classification models from MLJ via BalancedModel. This makes it possible to use BalancedModel to form hybrid resampling methods that combine oversampling and under-sampling methods in a linear pipeline such as SMOTE-Tomek and SMOTE-ENN.","category":"page"},{"location":"algorithms/mlj_balancing/#Construct-the-resampler-and-classification-models","page":"Combination","title":"Construct the resampler and classification models","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"SMOTE = @load SMOTE pkg=Imbalance verbosity=0\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n\noversampler = SMOTE(k=5, ratios=1.0, rng=42)\nundersampler = TomekUndersampler(min_ratios=0.5, rng=42)\n\nlogistic_model = LogisticClassifier()","category":"page"},{"location":"algorithms/mlj_balancing/#Wrap-them-all-in-BalancedModel","page":"Combination","title":"Wrap them all in BalancedModel","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"balanced_model = BalancedModel(model=logistic_model, \n balancer1=oversampler, balancer2=undersampler)","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Here training data will be passed to balancer1 then balancer2, whose output is used to train the classifier model. In prediction, the resamplers balancer1 and balancer2 are bypassed. The wrapped result behaves like one single MLJ model that can be fit, validated or fine-tuned like any other.","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"In general, there can be any number of balancers, and the user can give the balancers arbitrary names. ","category":"page"},{"location":"algorithms/mlj_balancing/#Parallel-Resampling-with-Balanced-Bagging","page":"Combination","title":"Parallel Resampling with Balanced Bagging","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"MLJBalancing.jl also offers an implementation of bagging over probabilistic classifiers, where the majority class is randomly undersampled T times down to the size of the minority class and a model is trained on each of the T undersampled datasets. The predictions are then aggregated by averaging, as the following sketch illustrates. 
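(A minimal sketch of the idea, our own illustration rather than the MLJBalancing source: each bag keeps every minority point plus a fresh, equal-sized random subset of the majority points; one model is then trained per bag and the predicted probabilities are averaged.)

using Random

function balanced_bag_indices(y, T; rng = Xoshiro(42))
    classes = unique(y)
    counts = [count(==(c), y) for c in classes]
    min_idx = findall(==(classes[argmin(counts)]), y)   # minority observations
    maj_idx = findall(==(classes[argmax(counts)]), y)   # majority observations
    # each bag: all minority points plus an equal-sized undersample of the majority
    [vcat(min_idx, shuffle(rng, maj_idx)[1:length(min_idx)]) for _ in 1:T]
end
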
This is offered via BalancedBaggingClassifier and can only be used for binary classification.","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"BalancedBaggingClassifier(model=nothing, T=0, rng = Random.default_rng(),)","category":"page"},{"location":"algorithms/mlj_balancing/#Arguments","page":"Combination","title":"Arguments","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"model::Probabilistic: A probabilistic classification model that implements the MLJModelInterface\nT::Integer=0: The number of bags to be used in the ensemble. If not given, will be set as the ratio between the frequency of the majority and minority classes.\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer ","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"seed to be used with Xoshiro","category":"page"},{"location":"algorithms/mlj_balancing/#Example","page":"Combination","title":"Example","text":"","category":"section"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"using MLJ\nusing Imbalance\nusing MLJBalancing\nusing Random # needed for Random.Xoshiro below\n\nX, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n probs = [0.9, 0.1], \n type = \"ColTable\", \n rng=42)\n\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\nlogistic_model = LogisticClassifier()\nbagging_model = BalancedBaggingClassifier(model=logistic_model, T=10, rng=Random.Xoshiro(42))","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"Now you can fit, predict, cross-validate and fine-tune it like any other probabilistic MLJ model, where X must be a table input (e.g., a dataframe).","category":"page"},{"location":"algorithms/mlj_balancing/","page":"Combination","title":"Combination","text":"mach = machine(bagging_model, X, y)\nfit!(mach)\npred = predict(mach, X)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#From-RandomOversampling-to-ROSE","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"\nimport Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\",\n \"ScientificTypes\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing ScientificTypes\nusing Imbalance\nusing Plots, Measures\nusing HTTP: download","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Loading-Data","page":"From RandomOversampling to ROSE","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Let's load the Iris dataset. The objective here is to predict the type of flower as one of \"virginica\", \"versicolor\" and \"setosa\" using its sepal and petal length and width. ","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"We don't need to load it from a CSV file this time because MLJ has a macro for loading it already! 
The only difference is that we will need to explicitly convert it to a dataframe, as MLJ loads it as a named tuple of vectors.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"X, y = @load_iris\nX = DataFrame(X)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌──────────────┬─────────────┬──────────────┬─────────────┐\n│ sepal_length │ sepal_width │ petal_length │ petal_width │\n│ Float64 │ Float64 │ Float64 │ Float64 │\n│ Continuous │ Continuous │ Continuous │ Continuous │\n├──────────────┼─────────────┼──────────────┼─────────────┤\n│ 5.1 │ 3.5 │ 1.4 │ 0.2 │\n│ 4.9 │ 3.0 │ 1.4 │ 0.2 │\n│ 4.7 │ 3.2 │ 1.3 │ 0.2 │\n│ 4.6 │ 3.1 │ 1.5 │ 0.2 │\n│ 5.0 │ 3.6 │ 1.4 │ 0.2 │\n└──────────────┴─────────────┴──────────────┴─────────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Our purpose for this tutorial is primarily visualization. Thus, let's select only two of the continuous features to work with. It's known that the petal length and width play a much bigger role in classifying the type of flower, so let's keep only those.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"X = select(X, :petal_width, :petal_length)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌─────────────┬──────────────┐\n│ petal_width │ petal_length │\n│ Float64 │ Float64 │\n│ Continuous │ Continuous │\n├─────────────┼──────────────┤\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.4 │\n│ 0.2 │ 1.3 │\n│ 0.2 │ 1.5 │\n│ 0.2 │ 1.4 │\n└─────────────┴──────────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Coercing-Data","page":"From RandomOversampling to ROSE","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"ScientificTypes.schema(X)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"┌──────────────┬────────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────┼────────────┼─────────┤\n│ petal_width │ Continuous │ Float64 │\n│ petal_length │ Continuous │ Float64 │\n└──────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Things look good, no coercion is needed.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Oversampling","page":"From RandomOversampling to ROSE","title":"Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Iris, by default, has no imbalance problem","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to 
ROSE","text":"checkbalance(y)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"virginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"To simulate that there is a balance problem, we will consider a random sample of 100 observations. A random sample does not guarantee perserving the proportion of classes; in this, we actually set the seed to get a very unlikely random sample that suffers from strong imbalance.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Random.seed!(803429)\nsubset_indices = rand(1:size(X, 1), 100)\nX, y = X[subset_indices, :], y[subset_indices]\ncheckbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"versicolor: ▇▇▇▇▇▇▇▇▇▇▇ 12 (22.6%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"We will treat this as our training set going forward so we don't need to partition. Now let's oversample it with ROSE.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Xover, yover = rose(X, y; s=0.3, ratios=Dict(\"versicolor\" => 1.0, \"setosa\"=>1.0))\ncheckbalance(yover)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Progress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[K\n\u001b[A\n\nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Training-the-Model","page":"From RandomOversampling to ROSE","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, 
:supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Let's go for a BayesianLDA.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"import Pkg; Pkg.add(\"MultivariateStats\")","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Before-Oversampling","page":"From RandomOversampling to ROSE","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"# 1. Load the model\nBayesianLDA = @load BayesianLDA pkg=MultivariateStats\n\n# 2. Instantiate it \nmodel = BayesianLDA()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#After-Oversampling","page":"From RandomOversampling to ROSE","title":"After Oversampling","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Plot-Decision-Boundaries","page":"From RandomOversampling to ROSE","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Construct ranges for each feature and consecutively a grid","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"petal_width_range =\n\trange(minimum(X.petal_width) - 1, maximum(X.petal_width) + 1, length = 200)\npetal_length_range =\n\trange(minimum(X.petal_length) - 1, maximum(X.petal_length) + 1, length = 200)\ngrid_points = [(pw, pl) for pw in petal_width_range, pl in petal_length_range]\n","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"200×200 Matrix{Tuple{Float64, Float64}}:\n (-0.9, 0.2) (-0.9, 0.238693) … (-0.9, 7.9)\n (-0.878894, 0.2) (-0.878894, 0.238693) (-0.878894, 7.9)\n (-0.857789, 0.2) (-0.857789, 0.238693) (-0.857789, 7.9)\n (-0.836683, 0.2) (-0.836683, 0.238693) (-0.836683, 7.9)\n (-0.815578, 0.2) (-0.815578, 0.238693) (-0.815578, 7.9)\n (-0.794472, 0.2) (-0.794472, 0.238693) … (-0.794472, 7.9)\n (-0.773367, 0.2) (-0.773367, 0.238693) (-0.773367, 7.9)\n (-0.752261, 0.2) (-0.752261, 0.238693) (-0.752261, 7.9)\n (-0.731156, 0.2) (-0.731156, 0.238693) (-0.731156, 7.9)\n (-0.71005, 0.2) (-0.71005, 0.238693) (-0.71005, 7.9)\n ⋮ ⋱ \n (3.13116, 0.2) (3.13116, 0.238693) (3.13116, 7.9)\n (3.15226, 0.2) (3.15226, 0.238693) (3.15226, 7.9)\n (3.17337, 0.2) (3.17337, 0.238693) (3.17337, 7.9)\n (3.19447, 0.2) (3.19447, 0.238693) (3.19447, 7.9)\n (3.21558, 0.2) (3.21558, 0.238693) … (3.21558, 7.9)\n (3.23668, 0.2) (3.23668, 0.238693) (3.23668, 7.9)\n (3.25779, 0.2) (3.25779, 0.238693) (3.25779, 7.9)\n (3.27889, 0.2) (3.27889, 0.238693) (3.27889, 7.9)\n (3.3, 0.2) (3.3, 0.238693) (3.3, 7.9)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Evaluate the grid with the machine before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"grid_predictions = [\n\tpredict_mode(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]\ngrid_predictions_over = [\n\tpredict_mode(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\tpoint in grid_points\n]","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"200×200 CategoricalArrays.CategoricalArray{String,2,UInt32}:\n \"setosa\" \"setosa\" \"setosa\" … \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" … \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" 
\"versicolor\" \"versicolor\"\n \"setosa\" \"setosa\" \"setosa\" \"versicolor\" \"versicolor\"\n ⋮ ⋱ \n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" … \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"\n \"virginica\" \"virginica\" \"virginica\" \"virginica\" \"virginica\"","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Make two contour plots using the grid predictions before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"p = contourf(petal_length_range, petal_width_range, grid_predictions,\n\tlevels = 3, color = :Set3_3, colorbar = false)\np_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\tlevels = 3, color = :Set3_3, colorbar = false)\nprintln()","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"Scatter plot the data before and after oversampling","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"old_count = size(X, 1)\nlabels = unique(y)\ncolors = Dict(\"setosa\" => \"green\", \"versicolor\" => \"yellow\",\n\t\"virginica\" => \"purple\")\n\nfor label in labels\n\tscatter!(p, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"Before Oversampling\")\n\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"After Oversampling\")\n\t# find new points only and plot with different shape\n\tscatter!(p_over, Xover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\tcolor = colors[label], label = label*\"-over\", markershape = :diamond,\n\t\ttitle = \"After Oversampling\")\nend\n\nplot_res = plot(\n\tp,\n\tp_over,\n\tlayout = (1, 2),\n\txlabel = \"petal length\",\n\tylabel = \"petal width\",\n\tsize = (900, 300),\n\tmargin = 5mm, dpi = 200\n)\nsavefig(plot_res, \"./assets/ROSE-before-after.png\")","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"(Image: Before After ROSE)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/#Effect-of-Increasing-s","page":"From RandomOversampling to ROSE","title":"Effect of Increasing s","text":"","category":"section"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"anim = @animate for s ∈ 0:0.03:6.0\n\t# oversample\n\tXover, yover =\n\t\trose(X, y; s = s, ratios = Dict(\"setosa\" => 1.0, \"versicolor\" => 1.0), rng = 42)\n\n\tmodel = BayesianLDA()\n\tmach_over = machine(model, 
Xover, yover)\n\tfit!(mach_over, verbosity = 0)\n\n\t# grid predictions\n\tgrid_predictions_over = [\n\t\tpredict_mode(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\tpoint in grid_points\n\t]\n\n\tp_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\t\tlevels = 3, color = :Set3_3, colorbar = false)\n\n\told_count = size(X, 1)\n\tfor label in labels\n\t\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\t\tcolor = colors[label], label = label,\n\t\t\ttitle = \"Oversampling with s=$s\")\n\t\t# find new points only and plot with different shape\n\t\tscatter!(p_over,\n\t\t\tXover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tcolor = colors[label], label = label * \"-over\", markershape = :diamond,\n\t\t\ttitle = \"Oversampling with s=$s\")\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"gif(anim, \"./assets/rose-animation.gif\", fps=6)\nprintln()","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"(Image: ROSE Effect of S)","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"As we can see, the larger s is, the more spread out the oversampled points are. This is expected because ROSE oversamples by sampling from the distribution obtained by placing a Gaussian on each existing point, and s is a hyperparameter proportional to the bandwidth of those Gaussians. When s=0, the only points that can be generated lie on top of existing ones; i.e., ROSE becomes equivalent to random oversampling.","category":"page"},{"location":"examples/effect_of_s/effect_of_s/","page":"From RandomOversampling to ROSE","title":"From RandomOversampling to ROSE","text":"The decision boundary is somewhat unstable across frames, since the model is refit on a freshly oversampled dataset for each value of s. Generating the animation still took plenty of time.","category":"page"},{"location":"about/#Credits","page":"About","title":"Credits","text":"","category":"section"},{"location":"about/","page":"About","title":"About","text":"This package was created by Essam Wisam as a Google Summer of Code project, under the mentorship of Anthony Blaom. 
Special thanks also go to Rik Huijzer for his friendliness and the binary SMOTE implementation in Resample.jl.","category":"page"},{"location":"algorithms/undersampling_algorithms/#Undersampling-Algorithms","page":"Undersampling","title":"Undersampling Algorithms","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"The following table portrays the supported undersampling algorithms, whether the mechanism deletes or generates new data and the supported types of data.","category":"page"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"Undersampling Method Mechanism Supported Data Types\nRandom Undersampler Delete existing data as needed Continuous and/or nominal\nCluster Undersampler Generate new data or delete existing data Continuous\nEdited Nearest Neighbors Undersampler Delete existing data meeting certain conditions (cleaning) Continuous\nTomek Links Undersampler Delete existing data meeting certain conditions (cleaning) Continuous","category":"page"},{"location":"algorithms/undersampling_algorithms/#Random-Undersampler","page":"Undersampling","title":"Random Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"random_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.random_undersample","page":"Undersampling","title":"Imbalance.random_undersample","text":"random_undersample(\n X, y; \n ratios=1.0, rng=default_rng(), \n try_preserve_type=true\n)\n\nDescription\n\nNaively undersample a dataset by randomly deleting existing observations.\n\nPositional Arguments\n\nX: A matrix of real numbers or a table with element scitypes that subtype Union{Finite, Infinite}. Elements in nominal columns should subtype Finite (i.e., have scitype OrderedFactor or Multiclass) and elements in continuous columns should subtype Infinite (i.e., have scitype Count or Continuous).\ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nratios=1.0: A parameter that controls the amount of undersampling to be done for each class\nCan be a float and in this case each class will be undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). 
This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n 1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n 2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n 0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply randomundersampling\nX_under, y_under = random_undersample(X, y; ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), \n rng=42)\n \njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the RandomUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nRandomUndersampler = @load RandomUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = RandomUndersampler(ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), \n rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be specified to the constructor to specify which column y is followed by other keyword arguments. Only Xy is provided while applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Initiate Random Undersampler model\nundersampler = RandomUndersampler(y_ind; ratios=Dict(0=>1.0, 1=>1.0, 2=>1.0), rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy) # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersample, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. 
You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/undersampling_algorithms/#Cluster-Undersampler","page":"Undersampling","title":"Cluster Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"cluster_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.cluster_undersample","page":"Undersampling","title":"Imbalance.cluster_undersample","text":"cluster_undersample(\n X, y; \n mode= \"nearest\", ratios = 1.0, maxiter = 100,\n rng=default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset using clustering undersampling as presented in [1] using K-means as the clustering algorithm.\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nmode::AbstractString=\"nearest\": If \"center\" then the undersampled data will consist of the centroids of each cluster found; meanwhile, if \"nearest\" then it will consist of the nearest neighbor of each centroid.\nratios=1.0: A parameter that controls the amount of undersampling to be done for each class\nCan be a float and in this case each class will be undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float ratio for that class\n\nmaxiter::Integer=100: Maximum number of iterations to run K-means\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n class_probs, rng=42) \n \njulia> Imbalance.checkbalance(y; ref=\"minority\")\n 1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n 2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n 0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply cluster undersampling\nX_under, y_under = cluster_undersample(X, y; mode=\"nearest\", \n ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), rng=42)\n \njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ClusterUndersampler model and pass the positional arguments X, y to the transform method. 
\n\nusing MLJ\nClusterUndersampler = @load ClusterUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = ClusterUndersampler(mode=\"nearest\", \n                                   ratios=Dict(0=>1.0, 1=> 1.0, 2=>1.0), rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be passed to the constructor to specify which column is y, followed by the other keyword arguments. Only Xy is provided when applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Instantiate ClusterUndersampler model\nundersampler = ClusterUndersampler(y_ind; mode=\"nearest\", \n                                   ratios=Dict(0=>1.0, 1=> 0.9, 2=>0.8), rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy)    # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Wei-Chao, L., Chih-Fong, T., Ya-Han, H., & Jing-Shang, J. (2017). Clustering-based undersampling in class-imbalanced data. Information Sciences, 409–410, 17–26.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/undersampling_algorithms/#Edited-Nearest-Neighbors-Undersampler","page":"Undersampling","title":"Edited Nearest Neighbors Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"enn_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.enn_undersample","page":"Undersampling","title":"Imbalance.enn_undersample","text":"enn_undersample(\n    X, y; k = 5, keep_condition = \"mode\",\n    min_ratios = 1.0, force_min_ratios = false,\n    rng = default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset by removing points that violate a certain condition, such as belonging to a different class than the majority of their neighbors, as proposed in [1].\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nk::Integer=5: Number of nearest neighbors to consider in the algorithm. Should be within the range 0 < k < n where n is the number of observations in the data. It will be automatically set to n-1 if n ≤ k.\n\nkeep_condition::AbstractString=\"mode\": The condition that leads to removing a point upon violation. 
Takes one of \"exists\", \"mode\", \"only mode\" and \"all\"\n\"exists\": the point has at least one neighbor from the same class\n\"mode\": the class of the point is one of the most frequent classes of the neighbors (there may be many)\n\"only mode\": the class of the point is the single most frequent class of the neighbors\n\"all\": the class of the point is the same as all the neighbors\n\nmin_ratios=1.0: A parameter that controls the maximum amount of undersampling to be done for each class. If this algorithm cleans the data to an extent that this is violated, some of the cleaned points will be revived randomly so that it is satisfied.\nCan be a float and in this case each class will be at most undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float minimum ratio for that class\n\nforce_min_ratios=false: If true, and this algorithm cleans the data such that the ratios for each class exceed those specified in min_ratios then further undersampling will be perform so that the final ratios are equal to min_ratios.\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister`.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling depending on whether the input X is a matrix or table respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n min_sep=0.01, stds=[3.0 3.0 3.0], class_probs, rng=42)\n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply enn undersampling\nX_under, y_under = enn_undersample(X, y; k=3, keep_condition=\"only mode\", \n min_ratios=0.5, rng=42)\n\njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 10 (100.0%) \n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 10 (100.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 24 (240.0%) \n\nMLJ Model Interface\n\nSimply pass the keyword arguments while initiating the ENNUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nENNUndersampler = @load ENNUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = ENNUndersampler(k=3, keep_condition=\"only mode\", min_ratios=0.5, rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. 
Hence, an integer y_ind must be passed to the constructor to specify which column is y, followed by the other keyword arguments. Only Xy is provided when applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 min_sep=0.01, stds=[3.0 3.0 3.0], class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Instantiate ENN Undersampler model\nundersampler = ENNUndersampler(y_ind; k=3, keep_condition=\"only mode\", rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy)    # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Dennis L. Wilson. Asymptotic properties of nearest neighbor rules using edited data. IEEE Transactions on Systems, Man, and Cybernetics, pages 408–421, 1972.\n\n\n\n\n\n","category":"function"},{"location":"algorithms/undersampling_algorithms/#Tomek-Links-Undersampler","page":"Undersampling","title":"Tomek Links Undersampler","text":"","category":"section"},{"location":"algorithms/undersampling_algorithms/","page":"Undersampling","title":"Undersampling","text":"tomek_undersample","category":"page"},{"location":"algorithms/undersampling_algorithms/#Imbalance.tomek_undersample","page":"Undersampling","title":"Imbalance.tomek_undersample","text":"tomek_undersample(\n    X, y;\n    min_ratios = 1.0, force_min_ratios = false,\n    rng = default_rng(), try_preserve_type=true\n)\n\nDescription\n\nUndersample a dataset by removing (\"cleaning\") any point that is part of a Tomek link in the data; a Tomek link is a pair of points from different classes that are each other's nearest neighbors. Tomek links are presented in [1].\n\nPositional Arguments\n\nX: A matrix or table of floats where each row is an observation from the dataset \ny: An abstract vector of labels (e.g., strings) that correspond to the observations in X\n\nKeyword Arguments\n\nmin_ratios=1.0: A parameter that controls the maximum amount of undersampling to be done for each class. If this algorithm cleans the data to an extent that this is violated, some of the cleaned points will be revived randomly so that it is satisfied.\nCan be a float and in this case each class will be at most undersampled to the size of the minority class times the float. By default, all classes are undersampled to the size of the minority class\nCan be a dictionary mapping each class label to the float minimum ratio for that class\n\nforce_min_ratios=false: If true, and this algorithm cleans the data such that the ratios for each class exceed those specified in min_ratios, then further undersampling will be performed so that the final ratios are equal to min_ratios.\n\nrng::Union{AbstractRNG, Integer}=default_rng(): Either an AbstractRNG object or an Integer seed to be used with Xoshiro if the Julia VERSION supports it. Otherwise, uses MersenneTwister.\n\ntry_preserve_type::Bool=true: When true, the function will try to not change the type of the input table (e.g., DataFrame). However, for some tables, this may not succeed, and in this case, the table returned will be a column table (named-tuple of vectors). 
This parameter is ignored if the input is a matrix.\n\nReturns\n\nX_under: A matrix or table that includes the data after undersampling, depending on whether the input X is a matrix or table, respectively\ny_under: An abstract vector of labels corresponding to X_under\n\nExample\n\nusing Imbalance\n\n# set probability of each class\nclass_probs = [0.5, 0.2, 0.3] \nnum_rows, num_continuous_feats = 100, 5\n# generate a table and categorical vector accordingly\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; \n                                min_sep=0.01, stds=[3.0 3.0 3.0], class_probs, rng=42) \n\njulia> Imbalance.checkbalance(y; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 33 (173.7%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 48 (252.6%) \n\n# apply tomek undersampling\nX_under, y_under = tomek_undersample(X, y; min_ratios=1.0, rng=42)\n\njulia> Imbalance.checkbalance(y_under; ref=\"minority\")\n1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 19 (100.0%) \n2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 22 (115.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 36 (189.5%)\n\nMLJ Model Interface\n\nSimply pass the keyword arguments when instantiating the TomekUndersampler model and pass the positional arguments X, y to the transform method. \n\nusing MLJ\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance\n\n# Wrap the model in a machine\nundersampler = TomekUndersampler(min_ratios=1.0, rng=42)\nmach = machine(undersampler)\n\n# Provide the data to transform (there is nothing to fit)\nX_under, y_under = transform(mach, X, y)\n\nYou can read more about this MLJ interface by accessing it from MLJ's model browser.\n\nTableTransforms Interface\n\nThis interface assumes that the input is one table Xy and that y is one of the columns. Hence, an integer y_ind must be passed to the constructor to specify which column is y, followed by the other keyword arguments. Only Xy is provided when applying the transform.\n\nusing Imbalance\nusing Imbalance.TableTransforms\n\n# Generate imbalanced data\nnum_rows = 100\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n                                 min_sep=0.01, stds=[3.0 3.0 3.0], class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42) \n\n# Instantiate TomekUndersampler model\nundersampler = TomekUndersampler(y_ind; min_ratios=1.0, rng=42)\nXy_under = Xy |> undersampler \nXy_under, cache = TableTransforms.apply(undersampler, Xy)    # equivalently\n\nThe reapply(undersampler, Xy, cache) method from TableTransforms simply falls back to apply(undersampler, Xy) and the revert(undersampler, Xy, cache) is not supported.\n\nIllustration\n\nA full basic example along with an animation can be found here. You may find more practical examples in the tutorial section which also explains running code on Google Colab.\n\nReferences\n\n[1] Ivan Tomek. Two modifications of CNN. IEEE Trans. 
Systems, Man and Cybernetics, 6:769–772, 1976.\n\n\n\n\n\n","category":"function"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Effect-of-ratios-Hyperparameter","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"using Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing ScientificTypes\nusing Imbalance\nusing Plots, Measures","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Loading-Data","page":"Effect of ratios Hyperparameter","title":"Loading Data","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Let's load the Iris dataset. The objective with this dataset is to predict the type of flower as one of \"virginica\", \"versicolor\", and \"setosa\" using its sepal and petal length and width.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"We don't need to load it from a CSV file this time because MLJ has a macro for loading it already! The only difference is that we will need to explicitly convert it to a dataframe, as MLJ loads it as a named tuple of vectors.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"X, y = @load_iris\nX = DataFrame(X)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌──────────────┬─────────────┬──────────────┬─────────────┐\n│ sepal_length │ sepal_width │ petal_length │ petal_width │\n│ Float64      │ Float64     │ Float64      │ Float64     │\n│ Continuous   │ Continuous  │ Continuous   │ Continuous  │\n├──────────────┼─────────────┼──────────────┼─────────────┤\n│ 5.1          │ 3.5         │ 1.4          │ 0.2         │\n│ 4.9          │ 3.0         │ 1.4          │ 0.2         │\n│ 4.7          │ 3.2         │ 1.3          │ 0.2         │\n│ 4.6          │ 3.1         │ 1.5          │ 0.2         │\n│ 5.0          │ 3.6         │ 1.4          │ 0.2         │\n└──────────────┴─────────────┴──────────────┴─────────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Our purpose in this tutorial is primarily visualization. Thus, let's select only two of the continuous features to work with. 
It's known that the petal length and width play a much bigger role in classifying the type of flower, so let's keep those only.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"X = select(X, :petal_width, :petal_length)\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌─────────────┬──────────────┐\n│ petal_width │ petal_length │\n│ Float64     │ Float64      │\n│ Continuous  │ Continuous   │\n├─────────────┼──────────────┤\n│ 0.2         │ 1.4          │\n│ 0.2         │ 1.4          │\n│ 0.2         │ 1.3          │\n│ 0.2         │ 1.5          │\n│ 0.2         │ 1.4          │\n└─────────────┴──────────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Coercing-Data","page":"Effect of ratios Hyperparameter","title":"Coercing Data","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"ScientificTypes.schema(X)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌──────────────┬────────────┬─────────┐\n│ names        │ scitypes   │ types   │\n├──────────────┼────────────┼─────────┤\n│ petal_width  │ Continuous │ Float64 │\n│ petal_length │ Continuous │ Float64 │\n└──────────────┴────────────┴─────────┘","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Things look good, no coercion is needed.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Oversampling","page":"Effect of ratios Hyperparameter","title":"Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Iris, by default, has no imbalance problem","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"checkbalance(y)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"virginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 50 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"To simulate a class imbalance problem, we will consider a random sample of 100 observations. 
A random sample does not guarantee preserving the proportion of classes; here, we actually set the seed to get a very unlikely random sample that suffers from moderate imbalance.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Random.seed!(803429)\nsubset_indices = rand(1:size(X, 1), 100)\nX, y = X[subset_indices, :], y[subset_indices]\ncheckbalance(y)        # comes from Imbalance","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"versicolor: ▇▇▇▇▇▇▇▇▇▇▇ 12 (22.6%) \nsetosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"We will treat this as our training set going forward so we don't need to partition. Now let's oversample it with SMOTE.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Xover, yover = smote(X, y; k=5, ratios=Dict(\"versicolor\" => 0.7), rng=42)\ncheckbalance(yover)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"setosa: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 35 (66.0%) \nversicolor: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 37 (69.8%) \nvirginica: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 53 (100.0%)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Training-the-Model","page":"Effect of ratios Hyperparameter","title":"Training the Model","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"53-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... )\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... 
)\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Let's go for an SVM","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"import Pkg;\nPkg.add(\"MLJLIBSVMInterface\");","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Before-Oversampling","page":"Effect of ratios Hyperparameter","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"# 1. Load the model\nSVC = @load SVC pkg = LIBSVM\n\n# 2. Instantiate it (γ=0.01 is intentional)\nmodel = SVC(gamma=0.01)\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X, y)\n\n# 4. fit the machine learning model\nfit!(mach)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌ Info: For silent loading, specify `verbosity=0`. \n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\nimport MLJLIBSVMInterface ✔\n\n\n┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: SVC(kernel = RadialBasis, …)\n args: \n 1:\tSource @527 ⏎ Table{AbstractVector{Continuous}}\n 2:\tSource @580 ⏎ AbstractVector{Multiclass{3}}","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#After-Oversampling","page":"Effect of ratios Hyperparameter","title":"After Oversampling","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. 
fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"┌ Info: Training machine(SVC(kernel = RadialBasis, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n  model: SVC(kernel = RadialBasis, …)\n  args: \n    1:\tSource @277 ⏎ Table{AbstractVector{Continuous}}\n    2:\tSource @977 ⏎ AbstractVector{Multiclass{3}}","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Plot-Decision-Boundaries","page":"Effect of ratios Hyperparameter","title":"Plot Decision Boundaries","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Construct ranges for each feature and subsequently a grid","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"petal_width_range =\n\trange(minimum(X.petal_width) - 1, maximum(X.petal_width) + 1, length = 200)\npetal_length_range =\n\trange(minimum(X.petal_length) - 1, maximum(X.petal_length) + 1, length = 200)\ngrid_points = [(pw, pl) for pw in petal_width_range, pl in petal_length_range]","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"200×200 Matrix{Tuple{Float64, Float64}}:\n (-0.9, 0.2)       (-0.9, 0.238693)       …  (-0.9, 7.9)\n (-0.878894, 0.2)  (-0.878894, 0.238693)     (-0.878894, 7.9)\n (-0.857789, 0.2)  (-0.857789, 0.238693)     (-0.857789, 7.9)\n (-0.836683, 0.2)  (-0.836683, 0.238693)     (-0.836683, 7.9)\n (-0.815578, 0.2)  (-0.815578, 0.238693)     (-0.815578, 7.9)\n (-0.794472, 0.2)  (-0.794472, 0.238693)  …  (-0.794472, 7.9)\n (-0.773367, 0.2)  (-0.773367, 0.238693)     (-0.773367, 7.9)\n (-0.752261, 0.2)  (-0.752261, 0.238693)     (-0.752261, 7.9)\n (-0.731156, 0.2)  (-0.731156, 0.238693)     (-0.731156, 7.9)\n (-0.71005, 0.2)   (-0.71005, 0.238693)      (-0.71005, 7.9)\n ⋮                                        ⋱  \n (3.13116, 0.2)    (3.13116, 0.238693)       (3.13116, 7.9)\n (3.15226, 0.2)    (3.15226, 0.238693)       (3.15226, 7.9)\n (3.17337, 0.2)    (3.17337, 0.238693)       (3.17337, 7.9)\n (3.19447, 0.2)    (3.19447, 0.238693)       (3.19447, 7.9)\n (3.21558, 0.2)    (3.21558, 0.238693)    …  (3.21558, 7.9)\n (3.23668, 0.2)    (3.23668, 0.238693)       (3.23668, 7.9)\n (3.25779, 0.2)    (3.25779, 0.238693)       (3.25779, 7.9)\n (3.27889, 0.2)    (3.27889, 0.238693)       (3.27889, 7.9)\n (3.3, 0.2)        (3.3, 0.238693)           (3.3, 7.9)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Evaluate the grid with the machine before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"grid_predictions =[\n    predict(mach, Tables.table(reshape(collect(point), 1, 2)))[1] for\n \tpoint in grid_points\n    ]\ngrid_predictions_over = [\n    predict(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n    point in grid_points\n]","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"200×200 
CategoricalArrays.CategoricalArray{String,2,UInt32}:\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n ⋮ ⋱ \n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" … \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"\n \"setosa\" \"setosa\" \"setosa\" \"setosa\" \"virginica\" \"virginica\"","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Make two contour plots using the grid predictions before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"p = contourf(petal_length_range, petal_width_range, grid_predictions,\n levels=3, color=:Set3_3, colorbar=false)\np_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n levels=3, color=:Set3_3, colorbar=false)\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Scatter plot the data before and after oversampling","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"old_count = size(X, 1)\n\ncolors = Dict(\"setosa\" => \"green\", \"versicolor\" => \"yellow\",\n\t\"virginica\" => \"purple\")\nlabels = unique(y)\nfor label in labels\n\tscatter!(p, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"Before Oversampling\")\n\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\tcolor = colors[label], label = label,\n\t\ttitle = \"After Oversampling\")\n\t# find new points only and plot with different shape\n\tscatter!(p_over, Xover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\tcolor = colors[label], label = label*\"-over\", markershape = :hexagon,\n\t\ttitle = \"After Oversampling\")\nend\n\nplot_res = plot(p, p_over, layout = (1, 2), xlabel = \"petal length\",\n\tylabel = \"petal width\", size = (900, 300), margin = 5mm, dpi = 200)\nsavefig(plot_res, 
\"./assets/before-after-smote.png\")\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"(Image: Before After SMOTE)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Notice how the minority class was completely ignore prior to oversampling. Not all models and hyperparameter settings are this delicate to class imbalance.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/#Effect-of-Ratios-Hyperparameter","page":"Effect of ratios Hyperparameter","title":"Effect of Ratios Hyperparameter","text":"","category":"section"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Now let's study the effect of the ratios hyperparameter. We will do this through an animated plot.","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"anim = @animate for versicolor_ratio ∈ 0.3:0.01:2\n\t# oversample\n\tXover, yover =\n\t\tsmote(X, y; k = 5, ratios = Dict(\"versicolor\" => versicolor_ratio), rng = 42)\n\n\t# fit machine\n\tmodel = SVC(gamma = 0.01)\n\tmach_over = machine(model, Xover, yover)\n\tfit!(mach_over, verbosity = 0)\n\n\t# grid predictions\n\tgrid_predictions_over = [\n\t\tpredict(mach_over, Tables.table(reshape(collect(point), 1, 2)))[1] for\n\t\tpoint in grid_points\n\t]\n\t# plot\n\tp_over = contourf(petal_length_range, petal_width_range, grid_predictions_over,\n\t\tlevels = 3, color = :Set3_3, colorbar = false)\n\told_count = size(X, 1)\n\tfor label in labels\n\t\tscatter!(p_over, X.petal_length[y.==label], X.petal_width[y.==label],\n\t\t\tcolor = colors[label], label = label,\n\t\t\ttitle = \"Oversampling versicolor with ratio $versicolor_ratio\")\n\t\t# find new points only and plot with different shape\n\t\tscatter!(p_over,\n\t\t\tXover.petal_length[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tXover.petal_width[old_count+1:end][yover[old_count+1:end].==label],\n\t\t\tcolor = colors[label], label = label * \"-over\", markershape = :hexagon,\n\t\t\ttitle = \"Oversampling versicolor with ratio $versicolor_ratio\")\n\tend\n\tplot!(dpi = 150)\nend\n","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"gif(anim, \"./assets/smote-animation.gif\", fps=6)\nprintln()","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"(Image: Ratios Parameter Effect)","category":"page"},{"location":"examples/effect_of_ratios/effect_of_ratios/","page":"Effect of ratios Hyperparameter","title":"Effect of ratios Hyperparameter","text":"Notice how setting ratios greedily can lead to overfitting.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#SMOTE-on-Customer-Churn-Data","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"\nimport Pkg;\nPkg.add([\"Random\", \"CSV\", 
\"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing HTTP: download","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Loading-Data","page":"SMOTE on Customer Churn Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"In this example, we will consider the Churn for Bank Customers found on Kaggle where the objective is to predict whether a customer is likely to leave a bank given financial and demographic features.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smote_churn_dataset/churn.csv\", \"./\")\ndf = CSV.read(\"./churn.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌───────────┬────────────┬──────────┬─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ RowNumber │ CustomerId │ Surname │ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ Int64 │ String31 │ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Count │ Textual │ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├───────────┼────────────┼──────────┼─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 1 │ 15634602 │ Hargrave │ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 2 │ 15647311 │ Hill │ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 3 │ 15619304 │ Onio │ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 4 │ 15701354 │ Boni │ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 5 │ 15737888 │ Mitchell │ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 │\n└───────────┴────────────┴──────────┴─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"There are plenty of useless columns that we can get rid of such as RowNumber and CustomerID. 
We also have to get rid of the categorical features because SMOTE won't be able to deal with those; however, other variants such as SMOTE-NC can, and we will consider one of those in another tutorial.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"df = df[:, Not([:RowNumber, :CustomerId, :Surname, \n                :Geography, :Gender])]\n\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ CreditScore │ Age   │ Tenure │ Balance    │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64       │ Int64 │ Int64  │ Float64    │ Int64         │ Int64     │ Int64          │ Float64         │ Int64  │\n│ Count       │ Count │ Count  │ Continuous │ Count         │ Count     │ Count          │ Continuous      │ Count  │\n├─────────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 619.0       │ 42.0  │ 2.0    │ 0.0        │ 1.0           │ 1.0       │ 1.0            │ 1.01349e5       │ 1.0    │\n│ 608.0       │ 41.0  │ 1.0    │ 83807.9    │ 1.0           │ 0.0       │ 1.0            │ 1.12543e5       │ 0.0    │\n│ 502.0       │ 42.0  │ 8.0    │ 1.59661e5  │ 3.0           │ 1.0       │ 0.0            │ 1.13932e5       │ 1.0    │\n│ 699.0       │ 39.0  │ 1.0    │ 0.0        │ 2.0           │ 0.0       │ 0.0            │ 93826.6         │ 0.0    │\n│ 850.0       │ 43.0  │ 2.0    │ 1.25511e5  │ 1.0           │ 1.0       │ 1.0            │ 79084.1         │ 0.0    │\n└─────────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Ideally, we may even remove ordinal variables because SMOTE will treat them as continuous and the synthetic data it generates will take floating-point values that will not occur in future data. 
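If one nevertheless wanted to feed such columns to plain SMOTE, one crude post-processing sketch (purely hypothetical; this tutorial does not do it, and it assumes Tenure and NumOfProducts are the ordinal columns of concern) is to round those columns after oversampling:\n\n# hypothetical clean-up; snaps synthetic ordinal values back to integers\nXover.Tenure = round.(Xover.Tenure)\nXover.NumOfProducts = round.(Xover.NumOfProducts)\n\n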
Some models may be robust to this regardless, and the main purpose of this tutorial is to later compare SMOTE-NC with SMOTE.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Coercing-Data","page":"SMOTE on Customer Churn Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Let's coerce everything to continuous except for the target variable.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"df = coerce(df, :Age=>Continuous,\n                :Tenure=>Continuous,\n                :Balance=>Continuous,\n                :NumOfProducts=>Continuous,\n                :HasCrCard=>Continuous,\n                :IsActiveMember=>Continuous,\n                :EstimatedSalary=>Continuous,\n                :Exited=>Multiclass)\n\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────────┬───────────────┬─────────────────────────────────┐\n│ names           │ scitypes      │ types                           │\n├─────────────────┼───────────────┼─────────────────────────────────┤\n│ CreditScore     │ Count         │ Int64                           │\n│ Age             │ Continuous    │ Float64                         │\n│ Tenure          │ Continuous    │ Float64                         │\n│ Balance         │ Continuous    │ Float64                         │\n│ NumOfProducts   │ Continuous    │ Float64                         │\n│ HasCrCard       │ Continuous    │ Float64                         │\n│ IsActiveMember  │ Continuous    │ Float64                         │\n│ EstimatedSalary │ Continuous    │ Float64                         │\n│ Exited          │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└─────────────────┴───────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Unpacking-and-Splitting-Data","page":"SMOTE on Customer Churn Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. 
We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y, X = unpack(df, ==(:Exited); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"┌─────────────┬────────────┬────────────┬────────────┬───────────────┬────────────┬────────────────┬─────────────────┐\n│ CreditScore │ Age        │ Tenure     │ Balance    │ NumOfProducts │ HasCrCard  │ IsActiveMember │ EstimatedSalary │\n│ Int64       │ Float64    │ Float64    │ Float64    │ Float64       │ Float64    │ Float64        │ Float64         │\n│ Count       │ Continuous │ Continuous │ Continuous │ Continuous    │ Continuous │ Continuous     │ Continuous      │\n├─────────────┼────────────┼────────────┼────────────┼───────────────┼────────────┼────────────────┼─────────────────┤\n│ 669.0       │ 31.0       │ 6.0        │ 1.13001e5  │ 1.0           │ 1.0        │ 0.0            │ 40467.8         │\n│ 822.0       │ 37.0       │ 3.0        │ 105563.0   │ 1.0           │ 1.0        │ 0.0            │ 1.82625e5       │\n│ 423.0       │ 36.0       │ 5.0        │ 97665.6    │ 1.0           │ 1.0        │ 0.0            │ 1.18373e5       │\n│ 623.0       │ 21.0       │ 10.0       │ 0.0        │ 2.0           │ 0.0        │ 1.0            │ 1.35851e5       │\n│ 691.0       │ 37.0       │ 7.0        │ 1.23068e5  │ 1.0           │ 1.0        │ 1.0            │ 98162.4         │\n└─────────────┴────────────┴────────────┴────────────┴───────────────┴────────────┴────────────────┴─────────────────┘","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Splitting the data into train and test portions is also easy using MLJ's partition function.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"(CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 1, 0, 0, 0, 0, 1, 0], CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 1, 1, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Oversampling","page":"SMOTE on Customer Churn Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Before deciding to oversample, let's see how adverse the imbalance problem is, if it exists. 
Ideally, you may as well check if the classification model is robust to this problem.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"checkbalance(y)         # comes from Imbalance","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 2037 (25.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7963 (100.0%)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Looks like we have a class imbalance problem. Let's oversample with SMOTE and set the desired ratios so that the positive minority class is 90% of the majority class","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Xover, yover = smote(X_train, y_train; k=3, ratios=Dict(1=>0.9), rng=42)\ncheckbalance(yover)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 5736 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 6373 (100.0%)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Training-the-Model","page":"SMOTE on Customer Churn Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Because we have scientific types set up, we can easily check what models will be able to train on our data. This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"models(matching(Xover, yover))","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"54-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = AdaBoostClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = AdaBoostStumpClassifier, package_name = DecisionTree, ... )\n (name = BaggingClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianLDA, package_name = MultivariateStats, ... )\n (name = BayesianQDA, package_name = MLJScikitLearnInterface, ... )\n (name = BayesianSubspaceLDA, package_name = MultivariateStats, ... 
)\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n ⋮\n (name = SGDClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVC, package_name = LIBSVM, ... )\n (name = SVMClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMLinearClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = SVMNuClassifier, package_name = MLJScikitLearnInterface, ... )\n (name = StableForestClassifier, package_name = SIRUS, ... )\n (name = StableRulesClassifier, package_name = SIRUS, ... )\n (name = SubspaceLDA, package_name = MultivariateStats, ... )\n (name = XGBoostClassifier, package_name = XGBoost, ... )","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Let's go for a logistic classifier from MLJLinearModels","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"import Pkg; Pkg.add(\"MLJLinearModels\")","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 1. Load the model\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n\n# 2. Instantiate it\nmodel = LogisticClassifier()\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train, scitype_check_level=0)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"trained Machine; caches model-specific representations of data\n  model: LogisticClassifier(lambda = 2.220446049250313e-16, …)\n  args: \n    1:\tSource @113 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n    2:\tSource @972 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Evaluating-the-Model","page":"SMOTE on Customer Churn Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"To evaluate the model, we will use the balanced accuracy metric, which accounts equally for all classes. 
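Concretely, the unadjusted balanced accuracy reported below is simply the average of the per-class recalls: for each class, take the fraction of its observations that were predicted correctly, then average these fractions over the classes. This is why a model that always predicts the majority class scores exactly 0.5 on a binary problem.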
","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling-2","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"0.5","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling-2","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"y_pred_over = predict_mode(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"0.57","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Evaluating-the-Model-Revisited","page":"SMOTE on Customer Churn Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy resulting in a 7% improvement. A more precise evaluation would use cross validation to combine many different point estimates into a more precise one (their average). 
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation would imply a smaller confidence interval at the same probability.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#Before-Oversampling-3","page":"SMOTE on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n  model, measure, operation, measurement, per_fold,\n  per_observation, fitted_params_per_fold,\n  report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬──────────┬─────────────────\n│ measure             │ operation    │ measurement │ 1.96*SE  │ per_fold       ⋯\n├─────────────────────┼──────────────┼─────────────┼──────────┼─────────────────\n│ BalancedAccuracy(   │ predict_mode │ 0.5         │ 3.29e-16 │ [0.5, 0.5, 0.5 ⋯\n│   adjusted = false) │              │             │          │                ⋯\n└─────────────────────┴──────────────┴─────────────┴──────────┴─────────────────\n                                                                1 column omitted","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"This looks good. Negligible standard deviation; point estimates are all centered around 0.5.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/#After-Oversampling-3","page":"SMOTE on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"At first glance, this seems really nontrivial since resampling will have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"For this, we must construct the resampling model via its MLJ interface, then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTE(k=3, ratios=Dict(1=>0.9), rng=42)\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = LogisticClassifier(lambda = 2.220446049250313e-16, …), …)\n args: \n 1:\tSource @991 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n 2:\tSource @939 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"true","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:00:00\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.552 │ 0.0145 │ [0.549, 0.563, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"The improvement is about 5.2% after cross-validation. If we further assume the scores to be normally distributed, then the 95% confidence interval is 5.2±1.45% improvement.
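To see where such an interval comes from, here is a minimal sketch of the computation from the per-fold scores (the scores vector below is hypothetical):","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"using Statistics\n\n# hypothetical per-fold balanced accuracies\nper_fold = [0.549, 0.563, 0.551, 0.544, 0.556]\n\nse = std(per_fold) / sqrt(length(per_fold))               # standard error of the mean\n(mean(per_fold) - 1.96 * se, mean(per_fold) + 1.96 * se)  # 95% confidence interval","category":"page"},{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"This mirrors the 1.96*SE column reported by evaluate! above.","category":"page"}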
,{"location":"examples/smote_churn_dataset/smote_churn_dataset/","page":"SMOTE on Customer Churn Data","title":"SMOTE on Customer Churn Data","text":"Let's see if this gets any better when we use SMOTE-NC instead, in a later example.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#SMOTENC-on-Customer-Churn-Data","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing HTTP: download","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Loading-Data","page":"SMOTENC on Customer Churn Data","title":"Loading Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"In this example, we will consider the Churn for Bank Customers dataset found on Kaggle, where the objective is to predict whether a customer is likely to leave a bank given financial and demographic features. ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We already considered this dataset in the SMOTE example; here we see whether the results are any better with SMOTE-NC.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/smotenc_churn_dataset/churn.csv\", \"./\")\ndf = CSV.read(\"./churn.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌───────────┬────────────┬──────────┬─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ RowNumber │ CustomerId │ Surname │ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ Int64 │ String31 │ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Count │ Textual │ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├───────────┼────────────┼──────────┼─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 1 │ 15634602 │ Hargrave │ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 2 │ 15647311 │ Hill │ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 3 │ 15619304 │ Onio │ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 4 │ 15701354 │ Boni │ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 5 │ 15737888 │ Mitchell │ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 
│\n└───────────┴────────────┴──────────┴─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's get rid of useless columns such as RowNumber and CustomerId","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"df = df[:, Not([:Surname, :RowNumber, :CustomerId])]\n\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────┬───────────┬─────────┬───────┬────────┬────────────┬───────────────┬───────────┬────────────────┬─────────────────┬────────┐\n│ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │ Exited │\n│ Int64 │ String7 │ String7 │ Int64 │ Int64 │ Float64 │ Int64 │ Int64 │ Int64 │ Float64 │ Int64 │\n│ Count │ Textual │ Textual │ Count │ Count │ Continuous │ Count │ Count │ Count │ Continuous │ Count │\n├─────────────┼───────────┼─────────┼───────┼────────┼────────────┼───────────────┼───────────┼────────────────┼─────────────────┼────────┤\n│ 619 │ France │ Female │ 42 │ 2 │ 0.0 │ 1 │ 1 │ 1 │ 1.01349e5 │ 1 │\n│ 608 │ Spain │ Female │ 41 │ 1 │ 83807.9 │ 1 │ 0 │ 1 │ 1.12543e5 │ 0 │\n│ 502 │ France │ Female │ 42 │ 8 │ 1.59661e5 │ 3 │ 1 │ 0 │ 1.13932e5 │ 1 │\n│ 699 │ France │ Female │ 39 │ 1 │ 0.0 │ 2 │ 0 │ 0 │ 93826.6 │ 0 │\n│ 850 │ Spain │ Female │ 43 │ 2 │ 1.25511e5 │ 1 │ 1 │ 1 │ 79084.1 │ 0 │\n└─────────────┴───────────┴─────────┴───────┴────────┴────────────┴───────────────┴───────────┴────────────────┴─────────────────┴────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Coercing-Data","page":"SMOTENC on Customer Churn Data","title":"Coercing Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's coerce the nominal data to Multiclass, the ordinal data to OrderedFactor and the continuous data to Continuous.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"df = coerce(df, \n :Geography => Multiclass, \n :Gender=> Multiclass,\n :CreditScore => OrderedFactor,\n :Age => OrderedFactor,\n :Tenure => OrderedFactor,\n :Balance => Continuous,\n :NumOfProducts => OrderedFactor,\n :HasCrCard => Multiclass,\n :IsActiveMember => Multiclass,\n :EstimatedSalary => Continuous,\n :Exited => Multiclass\n )\n\nScientificTypes.schema(df)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────────┬────────────────────┬───────────────────────────────────┐\n│ names │ scitypes │ types │\n├─────────────────┼────────────────────┼───────────────────────────────────┤\n│ CreditScore │ OrderedFactor{460} │ CategoricalValue{Int64, UInt32} │\n│ Geography │ Multiclass{3} │ CategoricalValue{String7, UInt32} │\n│ Gender │ Multiclass{2} │ CategoricalValue{String7, 
UInt32} │\n│ Age │ OrderedFactor{70} │ CategoricalValue{Int64, UInt32} │\n│ Tenure │ OrderedFactor{11} │ CategoricalValue{Int64, UInt32} │\n│ Balance │ Continuous │ Float64 │\n│ NumOfProducts │ OrderedFactor{4} │ CategoricalValue{Int64, UInt32} │\n│ HasCrCard │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ IsActiveMember │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ EstimatedSalary │ Continuous │ Float64 │\n│ Exited │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└─────────────────┴────────────────────┴───────────────────────────────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Unpacking-and-Splitting-Data","page":"SMOTENC on Customer Churn Data","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y, X = unpack(df, ==(:Exited); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌─────────────────────────────────┬───────────────────────────────────┬───────────────────────────────────┬─────────────────────────────────┬─────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬─────────────────────────────────┬─────────────────┐\n│ CreditScore │ Geography │ Gender │ Age │ Tenure │ Balance │ NumOfProducts │ HasCrCard │ IsActiveMember │ EstimatedSalary │\n│ CategoricalValue{Int64, UInt32} │ CategoricalValue{String7, UInt32} │ CategoricalValue{String7, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ Float64 │\n│ OrderedFactor{460} │ Multiclass{3} │ Multiclass{2} │ OrderedFactor{70} │ OrderedFactor{11} │ Continuous │ OrderedFactor{4} │ Multiclass{2} │ Multiclass{2} │ Continuous │\n├─────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────┼─────────────────────────────────┼─────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼─────────────────────────────────┼─────────────────┤\n│ 669 │ France │ Female │ 31 │ 6 │ 1.13001e5 │ 1 │ 1 │ 0 │ 40467.8 │\n│ 822 │ France │ Male │ 37 │ 3 │ 105563.0 │ 1 │ 1 │ 0 │ 1.82625e5 │\n│ 423 │ France │ Female │ 36 │ 5 │ 97665.6 │ 1 │ 1 │ 0 │ 1.18373e5 │\n│ 623 │ France │ Male │ 21 │ 10 │ 0.0 │ 2 │ 0 │ 1 │ 1.35851e5 │\n│ 691 │ Germany │ Female │ 37 │ 7 │ 1.23068e5 │ 1 │ 1 │ 1 │ 98162.4 │\n└─────────────────────────────────┴───────────────────────────────────┴───────────────────────────────────┴─────────────────────────────────┴─────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴─────────────────────────────────┴─────────────────┘","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"train_inds, test_inds = partition(eachindex(y), 0.8, shuffle=true, \n rng=Random.Xoshiro(42))\nX_train, X_test = X[train_inds, :], X[test_inds, :]\ny_train, y_test = y[train_inds], y[test_inds]","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer 
Churn Data","title":"SMOTENC on Customer Churn Data","text":"(CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0 … 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 1, 1, 0, 0, 0 … 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Oversampling","page":"SMOTENC on Customer Churn Data","title":"Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Before deciding to oversample, let's see how severe the imbalance problem is, if it exists at all. Ideally, you may also want to check whether the classification model is robust to this problem.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇ 2037 (25.6%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7963 (100.0%)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Looks like we have a class imbalance problem. Let's oversample with SMOTE-NC and set the desired ratios so that the positive minority class is 90% of the majority class","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Xover, yover = smotenc(X_train, y_train; k=3, ratios=Dict(1=>0.9), rng=42)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"(12109×10 DataFrame\n Row │ CreditScore Geography Gender Age Tenure Balance NumOfPr ⋯\n │ Cat… Cat… Cat… Cat… Cat… Float64 Cat… ⋯\n───────┼────────────────────────────────────────────────────────────────────────\n 1 │ 551 France Female 38 10 0.0 2 ⋯\n 2 │ 676 France Female 37 5 89634.7 1\n 3 │ 543 France Male 42 4 89838.7 3\n 4 │ 663 France Male 34 10 0.0 1\n 5 │ 621 Germany Female 34 2 91258.5 2 ⋯\n 6 │ 723 France Male 28 4 0.0 2\n 7 │ 735 France Female 21 1 1.78718e5 2\n 8 │ 501 France Male 35 6 99760.8 1\n ⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱\n 12103 │ 551 France Female 40 2 1.68002e5 1 ⋯\n 12104 │ 716 France Female 46 2 1.09379e5 2\n 12105 │ 850 Spain Female 45 10 1.66777e5 1\n 12106 │ 785 France Female 39 9 1.33118e5 1\n 12107 │ 565 Germany Female 39 5 1.44874e5 1 ⋯\n 12108 │ 510 Germany Male 43 0 1.38862e5 1\n 12109 │ 760 France Female 41 2 113419.0 1\n 4 columns and 12094 rows omitted, CategoricalValue{Int64, UInt32}[0, 1, 1, 0, 0, 0, 0, 0, 0, 0 … 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"checkbalance(yover)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 5736 (90.0%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 6373 (100.0%)","category":"page"}
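,{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"As a quick sanity check on the ratios semantics (using the counts reported above):","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"round(Int, 0.9 * 6373)  # = 5736, the class-1 count after oversampling","category":"page"}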
,{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Training-the-Model","page":"SMOTENC on Customer Churn Data","title":"Training the Model","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's find possible models","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"ms = models(matching(Xover, yover))","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"5-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Let's go for a decision tree classifier from BetaML.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We can't go for logistic regression as we did in the SMOTE tutorial because it does not support categorical features.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier(max_depth=4, rng=Random.Xoshiro(42))\n\n# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"import BetaML ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 4, …)\n args: \n 1:\tSource @378 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @049 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(model, Xover, yover)\n\n# 4. fit the machine learning model\nfit!(mach_over)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"┌ Info: Training machine(DecisionTreeClassifier(max_depth = 4, …), …).\n└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n\n\n\ntrained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 4, …)\n args: \n 1:\tSource @033 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @939 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Evaluating-the-Model","page":"SMOTENC on Customer Churn Data","title":"Evaluating the Model","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"To evaluate the model, we will use the balanced accuracy metric which equally accounts for all classes. 
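As a quick illustration, balanced accuracy is the unweighted mean of per-class recalls; here is a minimal sketch of that computation (a hypothetical helper, not the implementation MLJ uses):","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# hypothetical sketch: balanced accuracy as the mean of per-class recalls\nrecall_for(yp, yt, c) = sum((yp .== c) .& (yt .== c)) / sum(yt .== c)\nbalanced_acc(yp, yt) = sum(recall_for(yp, yt, c) for c in unique(yt)) / length(unique(yt))","category":"page"}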
","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling-2","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y_pred = predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"0.57","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling-2","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"y_pred_over = predict_mode(mach_over, X_test)\n\nscore = round(balanced_accuracy(y_pred_over, y_test), digits=2)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"0.7","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Although the results do get better compared to when we just used SMOTE, it may hold in this case that the extra categorical features we took into account are not be that important. The difference may be attributed to the decision tree.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Evaluating-the-Model-Revisited","page":"SMOTENC on Customer Churn Data","title":"Evaluating the Model - Revisited","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We have previously evaluated the model using a single point estimate of the balanced accuracy resulting in a 13% improvement. A more precise evaluation would use cross validation to combine many different point estimates into a more precise one (their average). 
The standard deviation among such point estimates also allows us to quantify the uncertainty of the estimate; a smaller standard deviation would imply a smaller confidence interval at the same confidence level.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#Before-Oversampling-3","page":"SMOTENC on Customer Churn Data","title":"Before Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:02:54\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.565 │ 0.00623 │ [0.568, 0.554, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Before oversampling, and assuming that the balanced accuracy score is normally distributed, we can be 95% confident that the balanced accuracy on new data is 56.5±0.62%. Indeed, this agrees closely with the original point estimate.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/#After-Oversampling-3","page":"SMOTENC on Customer Churn Data","title":"After Oversampling","text":"","category":"section"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"At first glance, this seems really nontrivial since resampling will have to be performed before training the model on each fold during cross-validation. Thankfully, the MLJBalancing package helps us avoid doing this manually by offering BalancedModel, where we can wrap any MLJ classification model with an arbitrary number of Imbalance.jl resamplers in a pipeline that behaves like a single MLJ model.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"For this, we must construct the resampling model via its MLJ interface and then pass it along with the classification model to BalancedModel.","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"# 2. Instantiate the models\noversampler = Imbalance.MLJ.SMOTENC(k=3, ratios=Dict(1=>0.9), rng=42)\n\n# 2.1 Wrap them in one model\nbalanced_model = BalancedModel(model=model, balancer1=oversampler)\n\n# 3. Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train, scitype_check_level=0)\n\n# 4. 
fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = DecisionTreeClassifier(max_depth = 4, …), …)\n args: \n 1:\tSource @967 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{3}}, AbstractVector{Multiclass{2}}, AbstractVector{OrderedFactor{460}}, AbstractVector{OrderedFactor{70}}, AbstractVector{OrderedFactor{11}}, AbstractVector{OrderedFactor{4}}}}\n 2:\tSource @394 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"We can easily confirm that this is equivalent to what we had earlier","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"predict_mode(mach_over, X_test) == y_pred_over","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"true","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Now let's cross-validate","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:07:24\u001b[K\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.677 │ 0.0124 │ [0.678, 0.688, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/smotenc_churn_dataset/smotenc_churn_dataset/","page":"SMOTENC on Customer Churn Data","title":"SMOTENC on Customer Churn Data","text":"Fair enough. 
After oversampling, the interval under the same assumptions is 67.7±1.2%, which is still a meaningful improvement over the 56.5±0.62% we had prior to oversampling, or the 55.2±1.5% we had with logistic regression and SMOTE in an earlier example.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Balanced-Bagging-for-Cerebral-Stroke-Prediction","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\",\"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Random\nusing CSV\nusing DataFrames\nusing MLJ\nusing Imbalance\nusing MLJBalancing\nusing StatsBase\nusing ScientificTypes\nusing Plots, Measures\nusing Impute\nusing HTTP: download","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Loading-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Loading Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"In this example, we will consider the Cerebral Stroke Prediction Dataset found on Kaggle, where the objective is to predict whether a stroke has occurred given medical features about patients.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"CSV gives us the ability to easily read the dataset after it's downloaded as follows","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/cerebral_ensemble/cerebral.csv\", \"./\")\ndf = CSV.read(\"./cerebral.csv\", DataFrame)\n\n# Display the first 5 rows with DataFrames\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────┬─────────┬────────────┬──────────────┬───────────────┬──────────────┬──────────────┬────────────────┬───────────────────┬────────────────────────────┬──────────────────────────┬────────┐\n│ id │ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ Int64 │ String7 │ Float64 │ Int64 │ Int64 │ String3 │ String15 │ String7 │ Float64 │ Union{Missing, Float64} │ Union{Missing, String15} │ Int64 │\n│ Count │ Textual │ Continuous │ Count │ Count │ Textual │ Textual │ Textual │ Continuous │ Union{Missing, Continuous} │ Union{Missing, Textual} │ Count │\n├───────┼─────────┼────────────┼──────────────┼───────────────┼──────────────┼──────────────┼────────────────┼───────────────────┼────────────────────────────┼──────────────────────────┼────────┤\n│ 30669 │ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ missing │ 0 │\n│ 30468 │ Male │ 
58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ 16523 │ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ missing │ 0 │\n│ 56543 │ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ 46136 │ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ missing │ 0 │\n└───────┴─────────┴────────────┴──────────────┴───────────────┴──────────────┴──────────────┴────────────────┴───────────────────┴────────────────────────────┴──────────────────────────┴────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"It's obvious that the id column is useless for predictions so we may as well drop it.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = df[:, Not(:id)]\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌─────────┬────────────┬──────────────┬───────────────┬──────────────┬──────────────┬────────────────┬───────────────────┬────────────────────────────┬──────────────────────────┬────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ String7 │ Float64 │ Int64 │ Int64 │ String3 │ String15 │ String7 │ Float64 │ Union{Missing, Float64} │ Union{Missing, String15} │ Int64 │\n│ Textual │ Continuous │ Count │ Count │ Textual │ Textual │ Textual │ Continuous │ Union{Missing, Continuous} │ Union{Missing, Textual} │ Count │\n├─────────┼────────────┼──────────────┼───────────────┼──────────────┼──────────────┼────────────────┼───────────────────┼────────────────────────────┼──────────────────────────┼────────┤\n│ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ missing │ 0 │\n│ Male │ 58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ missing │ 0 │\n│ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ missing │ 0 │\n└─────────┴────────────┴──────────────┴───────────────┴──────────────┴──────────────┴────────────────┴───────────────────┴────────────────────────────┴──────────────────────────┴────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Visualize-the-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Visualize the Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Since this dataset is composed mostly of categorical features, a bar chart for each categorical column is a good way to visualize the data.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# Create a bar chart for each column\nbar_charts = []\nfor col in names(df)\n counts = countmap(df[!, col])\n k, v = 
collect(keys(counts)), collect(values(counts))\n if length(k) < 20\n push!(bar_charts, bar(k, v, legend=false, title=col, color=\"turquoise3\", xrotation=90, margin=6mm))\n end\nend\n\n# Combine bar charts into a grid layout with specified plot size\nplot_res = plot(bar_charts..., layout=(3, 4),\n size=(1300, 500),\n dpi=200\n )\nsavefig(plot_res, \"./assets/cerebral-charts.png\")\n","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"(Image: Cerebral Stroke Features Plots)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Our target here is the stroke variable; notice how imbalanced it is.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Coercing-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Coercing Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Typical models from MLJ assume that elements in each column of a table have some scientific type as defined by the ScientificTypes.jl package. It's often necessary to coerce the types found by default to the appropriate type.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────┬────────────────────────────┬──────────────────────────┐\n│ names │ scitypes │ types │\n├───────────────────┼────────────────────────────┼──────────────────────────┤\n│ gender │ Textual │ String7 │\n│ age │ Continuous │ Float64 │\n│ hypertension │ Count │ Int64 │\n│ heart_disease │ Count │ Int64 │\n│ ever_married │ Textual │ String3 │\n│ work_type │ Textual │ String15 │\n│ Residence_type │ Textual │ String7 │\n│ avg_glucose_level │ Continuous │ Float64 │\n│ bmi │ Union{Missing, Continuous} │ Union{Missing, Float64} │\n│ smoking_status │ Union{Missing, Textual} │ Union{Missing, String15} │\n│ stroke │ Count │ Int64 │\n└───────────────────┴────────────────────────────┴──────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"For instance, here we need to coerce all the data to Multiclass, as they are all nominal variables, except for age, avg_glucose_level and bmi, which we can treat as Continuous","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = coerce(df, :gender => Multiclass, :age => Continuous, :hypertension => Multiclass,\n\t:heart_disease => Multiclass, :ever_married => Multiclass, :work_type => Multiclass,\n\t:Residence_type => Multiclass, :avg_glucose_level => Continuous,\n\t:bmi => Continuous, :smoking_status => Multiclass, :stroke => 
Multiclass,\n)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────┬───────────────┬────────────────────────────────────┐\n│ names │ scitypes │ types │\n├───────────────────┼───────────────┼────────────────────────────────────┤\n│ gender │ Multiclass{3} │ CategoricalValue{String7, UInt32} │\n│ age │ Continuous │ Float64 │\n│ hypertension │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ heart_disease │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n│ ever_married │ Multiclass{2} │ CategoricalValue{String3, UInt32} │\n│ work_type │ Multiclass{5} │ CategoricalValue{String15, UInt32} │\n│ Residence_type │ Multiclass{2} │ CategoricalValue{String7, UInt32} │\n│ avg_glucose_level │ Continuous │ Float64 │\n│ bmi │ Continuous │ Float64 │\n│ smoking_status │ Multiclass{3} │ CategoricalValue{String15, UInt32} │\n│ stroke │ Multiclass{2} │ CategoricalValue{Int64, UInt32} │\n└───────────────────┴───────────────┴────────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"As shown in the types, some columns have missing values. We will impute them using simple random sampling (each missing entry is replaced by a value drawn at random from the observed values of the same column), as dropping their rows would mean losing a big chunk of the dataset.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"df = Impute.srs(df); disallowmissing!(df)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬───────────────────────────────────┬────────────────────────────────────┬───────────────────────────────────┬───────────────────┬────────────┬────────────────────────────────────┬─────────────────────────────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │ stroke │\n│ CategoricalValue{String7, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{String3, UInt32} │ CategoricalValue{String15, UInt32} │ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │ CategoricalValue{String15, UInt32} │ CategoricalValue{Int64, UInt32} │\n│ Multiclass{3} │ Continuous │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{5} │ Multiclass{2} │ Continuous │ Continuous │ Multiclass{3} │ Multiclass{2} │\n├───────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼───────────────────────────────────┼────────────────────────────────────┼───────────────────────────────────┼───────────────────┼────────────┼────────────────────────────────────┼─────────────────────────────────┤\n│ Male │ 3.0 │ 0 │ 0 │ No │ children │ Rural │ 95.12 │ 18.0 │ formerly smoked │ 0 │\n│ Male │ 58.0 │ 1 │ 0 │ Yes │ Private │ Urban │ 87.96 │ 39.2 │ never smoked │ 0 │\n│ Female │ 8.0 │ 0 │ 0 │ No │ Private │ Urban │ 110.89 │ 17.6 │ never 
smoked │ 0 │\n│ Female │ 70.0 │ 0 │ 0 │ Yes │ Private │ Rural │ 69.04 │ 35.9 │ formerly smoked │ 0 │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Never_worked │ Rural │ 161.28 │ 19.1 │ formerly smoked │ 0 │\n└───────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴───────────────────────────────────┴────────────────────────────────────┴───────────────────────────────────┴───────────────────┴────────────┴────────────────────────────────────┴─────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Unpacking-and-Splitting-Data","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y, X = unpack(df, ==(:stroke); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"┌───────────────────────────────────┬────────────┬─────────────────────────────────┬─────────────────────────────────┬───────────────────────────────────┬────────────────────────────────────┬───────────────────────────────────┬───────────────────┬────────────┬────────────────────────────────────┐\n│ gender │ age │ hypertension │ heart_disease │ ever_married │ work_type │ Residence_type │ avg_glucose_level │ bmi │ smoking_status │\n│ CategoricalValue{String7, UInt32} │ Float64 │ CategoricalValue{Int64, UInt32} │ CategoricalValue{Int64, UInt32} │ CategoricalValue{String3, UInt32} │ CategoricalValue{String15, UInt32} │ CategoricalValue{String7, UInt32} │ Float64 │ Float64 │ CategoricalValue{String15, UInt32} │\n│ Multiclass{3} │ Continuous │ Multiclass{2} │ Multiclass{2} │ Multiclass{2} │ Multiclass{5} │ Multiclass{2} │ Continuous │ Continuous │ Multiclass{3} │\n├───────────────────────────────────┼────────────┼─────────────────────────────────┼─────────────────────────────────┼───────────────────────────────────┼────────────────────────────────────┼───────────────────────────────────┼───────────────────┼────────────┼────────────────────────────────────┤\n│ Female │ 37.0 │ 0 │ 0 │ Yes │ Private │ Urban │ 103.66 │ 36.1 │ smokes │\n│ Female │ 78.0 │ 0 │ 0 │ No │ Private │ Rural │ 83.97 │ 39.6 │ formerly smoked │\n│ Female │ 2.0 │ 0 │ 0 │ No │ children │ Urban │ 98.66 │ 17.0 │ smokes │\n│ Female │ 62.0 │ 0 │ 0 │ No │ Private │ Rural │ 205.41 │ 27.8 │ smokes │\n│ Male │ 14.0 │ 0 │ 0 │ No │ Private │ Rural │ 118.18 │ 24.5 │ never smoked │\n└───────────────────────────────────┴────────────┴─────────────────────────────────┴─────────────────────────────────┴───────────────────────────────────┴────────────────────────────────────┴───────────────────────────────────┴───────────────────┴────────────┴────────────────────────────────────┘","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for 
Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Splitting the data into train and test portions is also easy using MLJ's partition function. stratify=y guarantees that the data is distributed in the same proportions as the original dataset in both splits, which is more representative of the real world.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(42)\n)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"⚠️ Always split the data before oversampling. If your test data has oversampled observations then train-test contamination has occurred; novel observations will not come from the oversampling function.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Oversampling","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Oversampling","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"It was obvious from the bar charts that there is a severe imbalance problem. Let's look at that again.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"1: ▇ 783 (1.8%) \n0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 42617 (100.0%)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Indeed, it may be too severe for most models.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Training-the-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Training the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Because we have scientific types set up, we can easily check what models will be able to train on our data. 
This should guarantee that the model we choose won't throw an error due to types after feeding it the data.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"ms = models(matching(X, y))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"6-element Vector{NamedTuple{(:name, :package_name, :is_supervised, :abstract_type, :deep_properties, :docstring, :fit_data_scitype, :human_name, :hyperparameter_ranges, :hyperparameter_types, :hyperparameters, :implemented_methods, :inverse_transform_scitype, :is_pure_julia, :is_wrapper, :iteration_parameter, :load_path, :package_license, :package_url, :package_uuid, :predict_scitype, :prediction_type, :reporting_operations, :reports_feature_importances, :supports_class_weights, :supports_online, :supports_training_losses, :supports_weights, :transform_scitype, :input_scitype, :target_scitype, :output_scitype)}}:\n (name = CatBoostClassifier, package_name = CatBoost, ... )\n (name = ConstantClassifier, package_name = MLJModels, ... )\n (name = DecisionTreeClassifier, package_name = BetaML, ... )\n (name = DeterministicConstantClassifier, package_name = MLJModels, ... )\n (name = OneRuleClassifier, package_name = OneRule, ... )\n (name = RandomForestClassifier, package_name = BetaML, ... )","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Let's go for a DecisionTreeClassifier","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import Pkg; Pkg.add(\"BetaML\")","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":" Resolving package versions...\n Installed MLJBalancing ─ v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Project.toml`\n [45f359ea] + MLJBalancing v0.1.0\n Updating `~/Documents/GitHub/Imbalance.jl/docs/Manifest.toml`\n [45f359ea] + MLJBalancing v0.1.0\nPrecompiling project...\n ✓ MLJBalancing\n 1 dependency successfully precompiled in 25 seconds. 262 already precompiled.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Load-and-Construct","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Load and Construct","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 1. Load the model\nDecisionTreeClassifier = @load DecisionTreeClassifier pkg=BetaML\n\n# 2. Instantiate it\nmodel = DecisionTreeClassifier(max_depth=4)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"import BetaML ✔\n\n\n┌ Info: For silent loading, specify `verbosity=0`. 
\n└ @ Main /Users/essam/.julia/packages/MLJModels/EkXIe/src/loading.jl:159\n\n\n\nDecisionTreeClassifier(\n max_depth = 4, \n min_gain = 0.0, \n min_records = 2, \n max_features = 0, \n splitting_criterion = BetaML.Utils.gini, \n rng = Random._GLOBAL_RNG())","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Wrap-in-a-machine-and-fit!","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Wrap in a machine and fit!","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 3. Wrap it with the data in a machine\nmach = machine(model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach, verbosity=0)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"trained Machine; caches model-specific representations of data\n model: DecisionTreeClassifier(max_depth = 4, …)\n args: \n 1:\tSource @245 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{3}}}}\n 2:\tSource @251 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Evaluate-the-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Evaluate the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y_pred = MLJ.predict_mode(mach, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"0.5","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Training-BalancedBagging-Model","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Training BalancedBagging Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"The results suggest that the model is just as good as random guessing. Let's see if this gets better by using a BalancedBaggingClassifier. This classifier trains T copies of the given model on T undersampled versions of the dataset, where each undersampled version has as many majority examples as there are minority examples.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"This approach allows us to work around the imbalance issue without losing any data.
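The idea can be sketched as follows (a hypothetical illustration of the mechanism, not the MLJBalancing implementation):","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"using Random, StatsBase\n\n# hypothetical sketch: row indices for T balanced bags, each holding all minority\n# rows plus an equally sized random sample of majority rows (without replacement)\nfunction balanced_bag_indices(y, T; rng=Random.Xoshiro(42))\n    counts = countmap(y)\n    minority = first(sort(collect(keys(counts)); by=c -> counts[c]))\n    min_idx = findall(==(minority), y)\n    maj_idx = findall(!=(minority), y)\n    return [vcat(min_idx, sample(rng, maj_idx, length(min_idx); replace=false)) for _ in 1:T]\nend","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"One copy of the model is then trained on each bag, and predictions are aggregated over the T copies (e.g., by averaging predicted probabilities). 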
For instance, if we set T=round(Int, 100/1.8) (which is what the default amounts to when the minority class makes up about 1.8% of the data, as here) then on average each majority example will be used in one of the T bags.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Load-and-Construct-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Load and Construct","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"bagging_model = BalancedBaggingClassifier(model=model, T=30, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"BalancedBaggingClassifier(\n model = DecisionTreeClassifier(\n max_depth = 4, \n min_gain = 0.0, \n min_records = 2, \n max_features = 0, \n splitting_criterion = BetaML.Utils.gini, \n rng = Random._GLOBAL_RNG()), \n T = 30, \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Wrap-in-a-machine-and-fit!-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Wrap in a machine and fit!","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"# 3. Wrap it with the data in a machine\nmach_over = machine(bagging_model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"trained Machine; does not cache data\n model: BalancedBaggingClassifier(model = DecisionTreeClassifier(max_depth = 4, …), …)\n args: \n 1:\tSource @005 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Multiclass{5}}, AbstractVector{Multiclass{2}}, AbstractVector{Multiclass{3}}}}\n 2:\tSource @531 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/#Evaluate-the-Model-2","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Evaluate the Model","text":"","category":"section"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"y_pred = MLJ.predict_mode(mach_over, X_test) \n\nscore = round(balanced_accuracy(y_pred, y_test), digits=2)","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"0.77","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"This is a dramatic improvement over what we had before. 
Let's confirm with cross-validation.","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy, operation=predict_mode) ","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Evaluating over 10 folds: 100%[=========================] Time: 0:01:40\n\n\n\nPerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.772 │ 0.0146 │ [0.738, 0.769, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/cerebral_ensemble/cerebral_ensemble/","page":"Balanced Bagging for Cerebral Stroke Prediction","title":"Balanced Bagging for Cerebral Stroke Prediction","text":"Assuming the scores are approximately normally distributed, the 95% confidence interval is 77.2±1.4% for the balanced accuracy.","category":"page"},{"location":"examples/","page":"More Examples","title":"More Examples","text":" \n \n\n","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#SMOTE-Tomek-for-Ethereum-Fraud-Detection","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"import Pkg;\nPkg.add([\"Random\", \"CSV\", \"DataFrames\", \"MLJ\", \"Imbalance\", \"MLJBalancing\", \n \"ScientificTypes\", \"Impute\", \"StatsBase\", \"Plots\", \"Measures\", \"HTTP\"])\n\nusing Imbalance\nusing MLJBalancing\nusing CSV\nusing DataFrames\nusing ScientificTypes\nusing CategoricalArrays\nusing MLJ\nusing Plots\nusing Random\nusing Impute\nusing HTTP: download","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Loading-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Loading Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"In this example, we will consider the Ethereum Fraud Detection Dataset found on Kaggle, where the objective is to predict whether an Ethereum transaction is fraudulent or not (the FLAG column) given some features about the transaction.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"CSV allows us to easily read the dataset once it has been downloaded, as follows","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud 
Detection","text":"download(\"https://raw.githubusercontent.com/JuliaAI/Imbalance.jl/dev/docs/src/examples/fraud_detection/transactions.csv\", \"./\")\n\ndf = CSV.read(\"./transactions.csv\", DataFrame)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"There are plenty of useless columns that we can get rid of such as Column1, Index and probably, Address. We also have to get rid of the categorical features because SMOTE won't be able to deal with those and it leaves us with more options for the model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = df[:,\n\tNot([\n\t\t:Column1,\n\t\t:Index,\n\t\t:Address,\n\t\tSymbol(\" ERC20 most sent token type\"),\n\t\tSymbol(\" ERC20_most_rec_token_type\"),\n\t]),\n] \nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"If you scroll through the printed data frame, you find that some columns also have Missing for their element type, meaning that they may be containing missing values. We will use linear interpolation, last-observation carried forward and next observation carried backward techniques to fill up the missing values. This will allow us to call disallowmissing!(df) to return a dataframe where Missing is not an element type for any column.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = Impute.interp(df) |> Impute.locf() |> Impute.nocb(); disallowmissing!(df)\nfirst(df, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Coercing-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Coercing Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Let's look at the schema first","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"ScientificTypes.schema(df)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"┌──────────────────────────────────────────────────────┬────────────┬─────────┐\n│ names │ scitypes │ types │\n├──────────────────────────────────────────────────────┼────────────┼─────────┤\n│ FLAG │ Count │ Int64 │\n│ Avg min between sent tnx │ Continuous │ Float64 │\n│ Avg min between received tnx │ Continuous │ Float64 │\n│ Time Diff between first and last (Mins) │ Continuous │ Float64 │\n│ Sent tnx │ Count │ Int64 │\n│ Received Tnx │ Count │ Int64 │\n│ Number of Created Contracts │ Count │ Int64 │\n│ Unique Received From Addresses │ Count │ Int64 │\n│ Unique Sent To Addresses │ Count │ Int64 │\n│ min value received │ Continuous │ Float64 │\n│ max value received │ Continuous │ Float64 │\n│ avg val received │ Continuous │ Float64 │\n│ min val sent │ Continuous │ Float64 │\n│ max val sent │ Continuous 
│ Float64 │\n│ avg val sent │ Continuous │ Float64 │\n│ min value sent to contract │ Continuous │ Float64 │\n│ ⋮ │ ⋮ │ ⋮ │\n└──────────────────────────────────────────────────────┴────────────┴─────────┘\n 30 rows omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"The FLAG target should definitely be Multiclass; the rest seem fine.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"df = coerce(df, :FLAG => Multiclass)\nScientificTypes.schema(df)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"┌──────────────────────────────────────────────────────┬───────────────┬────────\n│ names │ scitypes │ types ⋯\n├──────────────────────────────────────────────────────┼───────────────┼────────\n│ FLAG │ Multiclass{2} │ Categ ⋯\n│ Avg min between sent tnx │ Continuous │ Float ⋯\n│ Avg min between received tnx │ Continuous │ Float ⋯\n│ Time Diff between first and last (Mins) │ Continuous │ Float ⋯\n│ Sent tnx │ Count │ Int64 ⋯\n│ Received Tnx │ Count │ Int64 ⋯\n│ Number of Created Contracts │ Count │ Int64 ⋯\n│ Unique Received From Addresses │ Count │ Int64 ⋯\n│ Unique Sent To Addresses │ Count │ Int64 ⋯\n│ min value received │ Continuous │ Float ⋯\n│ max value received │ Continuous │ Float ⋯\n│ avg val received │ Continuous │ Float ⋯\n│ min val sent │ Continuous │ Float ⋯\n│ max val sent │ Continuous │ Float ⋯\n│ avg val sent │ Continuous │ Float ⋯\n│ min value sent to contract │ Continuous │ Float ⋯\n│ ⋮ │ ⋮ │ ⋱\n└──────────────────────────────────────────────────────┴───────────────┴────────\n 1 column and 30 rows omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Unpacking-and-Splitting-Data","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Unpacking and Splitting Data","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Both MLJ and the pure functional interface of Imbalance assume that the observations table X and target vector y are separate. 
We can accomplish that by using unpack from MLJ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"y, X = unpack(df, ==(:FLAG); rng=123);\nfirst(X, 5) |> pretty","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Splitting the data into train and test portions is also easy using MLJ's partition function.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"(X_train, X_test), (y_train, y_test) = partition(\n\t(X, y),\n\t0.8,\n\tmulti = true,\n\tshuffle = true,\n\tstratify = y,\n\trng = Random.Xoshiro(41)\n)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Resampling","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Resampling","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Before deciding to oversample, let's see how severe the imbalance problem is, if it exists at all. Ideally, you would also check whether the classification model is robust to this problem.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"checkbalance(y) # comes from Imbalance","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"This signals a potential class imbalance problem. Let's consider using SMOTE-Tomek to resample this data. The SMOTE-Tomek algorithm is nothing but SMOTE oversampling followed by Tomek-links undersampling. We can wrap these in a pipeline along with a classification model for predictions using BalancedModel from MLJBalancing. 
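If you prefer the pure functional interface, the same two-step resampling can also be applied manually before training; here is a sketch using the functional counterparts smote and tomek_undersample, with the same illustrative parameters as the pipeline constructed below:\n\n# sketch: oversample the fraud class, then clean Tomek links\nXover, yover = smote(X_train, y_train; k=5, ratios=Dict(1=>0.5), rng=Random.Xoshiro(42))\nXres, yres = tomek_undersample(Xover, yover; min_ratios=Dict(0=>1.3), force_min_ratios=true)\n\nThe BalancedModel wrapper spares us from doing this by hand and also keeps resampling out of the prediction path. 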
Let's go for a RandomForestClassifier from DecisionTree.jl for the model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"import Pkg; Pkg.add(\"DecisionTree\")","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Construct-the-Resampling-and-Classification-Models","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Construct the Resampling & Classification Models","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"oversampler = Imbalance.MLJ.SMOTE(ratios=Dict(1=>0.5), rng=Random.Xoshiro(42))\nundersampler = Imbalance.MLJ.TomekUndersampler(min_ratios=Dict(0=>1.3), force_min_ratios=true)\nRandomForestClassifier = @load RandomForestClassifier pkg=DecisionTree\nmodel = RandomForestClassifier(n_trees=2, rng=Random.Xoshiro(42))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"RandomForestClassifier(\n max_depth = -1, \n min_samples_leaf = 1, \n min_samples_split = 2, \n min_purity_increase = 0.0, \n n_subfeatures = -1, \n n_trees = 2, \n sampling_fraction = 0.7, \n feature_importance = :impurity, \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Form-the-Pipeline-using-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Form the Pipeline using BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"balanced_model = BalancedModel(model=model, balancer1=oversampler, balancer2=undersampler)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"BalancedModelProbabilistic(\n model = RandomForestClassifier(\n max_depth = -1, \n min_samples_leaf = 1, \n min_samples_split = 2, \n min_purity_increase = 0.0, \n n_subfeatures = -1, \n n_trees = 2, \n sampling_fraction = 0.7, \n feature_importance = :impurity, \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1)), \n balancer1 = SMOTE(\n k = 5, \n ratios = Dict(1 => 0.5), \n rng = Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1), \n try_preserve_type = true), \n balancer2 = TomekUndersampler(\n min_ratios = Dict(0 => 1.3), \n force_min_ratios = true, \n rng = TaskLocalRNG(), \n try_preserve_type = true))","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Now we can treat balanced_model like any MLJ model.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Fit-the-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Fit the BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"# 3. 
Wrap it with the data in a machine\nmach_over = machine(balanced_model, X_train, y_train)\n\n# 4. fit the machine learning model\nfit!(mach_over, verbosity=0)","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"trained Machine; does not cache data\n model: BalancedModelProbabilistic(model = RandomForestClassifier(max_depth = -1, …), …)\n args: \n 1:\tSource @967 ⏎ Table{Union{AbstractVector{Continuous}, AbstractVector{Count}}}\n 2:\tSource @913 ⏎ AbstractVector{Multiclass{2}}","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Validate-the-BalancedModel","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Validate the BalancedModel","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"cv=CV(nfolds=10)\nevaluate!(mach_over, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.93 │ 0.00757 │ [0.927, 0.936, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/#Compare-with-RandomForestClassifier-only","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"Compare with RandomForestClassifier only","text":"","category":"section"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"To see whether this represents an improvement, let's fit and validate the original model by itself.","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"# 3. 
Wrap it with the data in a machine\nmach = machine(model, X_train, y_train, scitype_check_level=0)\nfit!(mach)\n\nevaluate!(mach, resampling=cv, measure=balanced_accuracy) ","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"PerformanceEvaluation object with these fields:\n model, measure, operation, measurement, per_fold,\n per_observation, fitted_params_per_fold,\n report_per_fold, train_test_rows, resampling, repeats\nExtract:\n┌─────────────────────┬──────────────┬─────────────┬─────────┬──────────────────\n│ measure │ operation │ measurement │ 1.96*SE │ per_fold ⋯\n├─────────────────────┼──────────────┼─────────────┼─────────┼──────────────────\n│ BalancedAccuracy( │ predict_mode │ 0.908 │ 0.00932 │ [0.903, 0.898, ⋯\n│ adjusted = false) │ │ │ │ ⋯\n└─────────────────────┴──────────────┴─────────────┴─────────┴──────────────────\n 1 column omitted","category":"page"},{"location":"examples/fraud_detection/fraud_detection/","page":"SMOTE-Tomek for Ethereum Fraud Detection","title":"SMOTE-Tomek for Ethereum Fraud Detection","text":"Assuming approximately normal scores, the 95% confidence interval was 90.8±0.9% before resampling and 93.0±0.7% after, which corresponds to a small improvement in balanced accuracy.","category":"page"},{"location":"examples/Colab/#Google-Colab","page":"Google Colab","title":"Google Colab","text":"","category":"section"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"It is possible to run the tutorials found in the examples section, or the API documentation, on Google Colab (using the provided link or icon). How to do so should be evident upon launching a notebook. This section describes what happens under the hood.","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"The first cell runs the following bash script to install Julia:","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"%%capture\n%%shell\nif ! 
command -v julia 3>&1 > /dev/null\nthen\n wget -q 'https://julialang-s3.julialang.org/bin/linux/x64/1.7/julia-1.7.2-linux-x86_64.tar.gz' \\\n -O /tmp/julia.tar.gz\n tar -x -f /tmp/julia.tar.gz -C /usr/local --strip-components 1\n rm /tmp/julia.tar.gz\nfi\njulia -e 'using Pkg; pkg\"add IJulia; precompile;\"'\necho 'Done'","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"Once that is done, one can change the runtime to Julia by choosing Runtime from the toolbar, then Change runtime type; at this point, the installation cell can be deleted","category":"page"},{"location":"examples/Colab/","page":"Google Colab","title":"Google Colab","text":"Sincere thanks to Julia-on-Colab for making this possible.","category":"page"},{"location":"#Imbalance.jl","page":"Introduction","title":"Imbalance.jl","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"(Image: Imbalance)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"A Julia package with resampling methods to correct for class imbalance in a wide variety of classification settings.","category":"page"},{"location":"#Installation","page":"Introduction","title":"Installation","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"import Pkg;\nPkg.add(\"Imbalance\")","category":"page"},{"location":"#Implemented-Methods","page":"Introduction","title":"Implemented Methods","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"The package implements the following resampling algorithms","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Random Oversampling\nRandom Walk Oversampling (RWO)\nRandom Oversampling Examples (ROSE)\nSynthetic Minority Oversampling Technique (SMOTE)\nBorderline SMOTE1\nSMOTE-Nominal (SMOTE-N)\nSMOTE-Nominal Continuous (SMOTE-NC)\nRandom Undersampling\nCluster Undersampling\nEdited Nearest Neighbors Undersampling\nTomek Links Undersampling\nBalanced Bagging Classifier (from MLJBalancing.jl)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"To see various examples where such methods help improve classification performance, check the tutorials section of the documentation.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Interested in contributing more? 
Check this.","category":"page"},{"location":"#Quick-Start","page":"Introduction","title":"Quick Start","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"We will illustrate using the package to oversample with SMOTE; however, all other implemented oversampling methods follow the same pattern.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Let's start by generating some dummy imbalanced data:","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\n\n# Set dataset properties then generate imbalanced data\nclass_probs = [0.5, 0.2, 0.3] # probability of each class \nnum_rows, num_continuous_feats = 100, 5\nX, y = generate_imbalanced_data(num_rows, num_continuous_feats; class_probs, rng=42)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"In the following code blocks, it will be assumed that X and y are readily available.","category":"page"},{"location":"#Standard-API","page":"Introduction","title":"Standard API","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"All methods by default support a pure functional interface.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\n\n# Apply SMOTE to oversample the classes\nXover, yover = smote(X, y; k=5, ratios=Dict(0=>1.0, 1=>0.9, 2=>0.8), rng=42)","category":"page"},{"location":"#MLJ-Interface","page":"Introduction","title":"MLJ Interface","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"All methods support the MLJ interface. Instead of calling the method directly, one instantiates a model for it (optionally passing the keyword parameters found in the functional interface), wraps the model in a machine, and then calls transform on the machine and the data.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using MLJ\n\n# Load the model\nSMOTE = @load SMOTE pkg=Imbalance\n\n# Create an instance of the model \noversampler = SMOTE(k=5, ratios=Dict(0=>1.0, 1=>0.9, 2=>0.8), rng=42)\n\n# Wrap it in a machine\nmach = machine(oversampler)\n\n# Provide the data to transform \nXover, yover = transform(mach, X, y)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"All implemented oversampling methods are considered static transforms; hence, no fit is required. 
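As a quick sanity check that the requested ratios were achieved, the class distribution of the result can be inspected (output omitted here):\n\nImbalance.checkbalance(yover)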
","category":"page"},{"location":"#Pipelining-Models","page":"Introduction","title":"Pipelining Models","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"If MLJBalancing is also used, an arbitrary number of resampling methods from Imbalance.jl can be wrapped with a classification model from MLJ to function as a unified model where resampling automatically takes place on given data before training the model (and is bypassed during prediction).","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using MLJ, MLJBalancing\n\n# grab two resamplers and a classifier\nLogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\nSMOTE = @load SMOTE pkg=Imbalance verbosity=0\nTomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0\n\noversampler = SMOTE(k=5, ratios=1.0, rng=42)\nundersampler = TomekUndersampler(min_ratios=0.5, rng=42)\nlogistic_model = LogisticClassifier()\n\n# wrap the oversampler, undersampler, and classification model together\nbalanced_model = BalancedModel(model=logistic_model, \n balancer1=oversampler, balancer2=undersampler)\n\n# behaves like a single model\nmach = machine(balanced_model, X, y);\nfit!(mach, verbosity=0)\npredict(mach, X)","category":"page"},{"location":"#Table-Transforms-Interface","page":"Introduction","title":"Table Transforms Interface","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"The TableTransforms interface operates on single tables; it assumes that y is one of the columns of the given table. Thus, it follows a similar pattern to the MLJ interface, except that the index of y is a required argument when instantiating the model, and the data to be transformed via apply is a single table Xy.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"using Imbalance\nusing Imbalance.TableTransforms\nusing TableTransforms\n\n# Generate imbalanced data\nnum_rows = 200\nnum_features = 5\ny_ind = 3\nXy, _ = generate_imbalanced_data(num_rows, num_features; \n class_probs=[0.5, 0.2, 0.3], insert_y=y_ind, rng=42)\n\n# Instantiate SMOTE model\noversampler = SMOTE(y_ind; k=5, ratios=Dict(0=>1.0, 1=>0.9, 2=>0.8), rng=42)\nXyover = Xy |> oversampler # can chain with other table transforms \n# or, equivalently, if TableTransforms is used\nXyover, cache = TableTransforms.apply(oversampler, Xy)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"The reapply(oversampler, Xy, cache) method from TableTransforms simply falls back to apply(oversampler, Xy) and the revert(oversampler, Xy, cache) method reverts the transform by removing the oversampled observations from the table.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Notice that because the interfaces of MLJ and TableTransforms use the same model names, you will have to specify the source of the model if both are used in the same file (e.g., Imbalance.TableTransforms.SMOTE) for the example above.","category":"page"},{"location":"#Features","page":"Introduction","title":"Features","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"Supports multi-class variants of the algorithms and both nominal and continuous features\nProvides MLJ and TableTransforms interfaces aside from the default pure functional interface\nGeneric by supporting table input/output formats as well as matrices\nSupports tables regardless of whether the target is a separate column or one of the columns\nSupports automatic encoding and decoding of nominal features","category":"page"},{"location":"#Rationale","page":"Introduction","title":"Rationale","text":"","category":"section"},{"location":"","page":"Introduction","title":"Introduction","text":"Most if not all machine learning algorithms can be viewed as a form of empirical risk minimization, where the objective is to find the parameters θ that, for some loss function L, minimize ","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"\\hat{\\theta} = \\arg\\min_{\\theta} \\frac{1}{N} \\sum_{i=1}^{N} L(f_{\\theta}(x_i), y_i)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"The underlying assumption is that minimizing this empirical risk corresponds to approximately minimizing the true risk, which considers all examples in the population; this would imply that f_θ is approximately the true target function f that we seek to model.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"In a multi-class setting with K classes, one can write","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"\\hat{\\theta} = \\arg\\min_{\\theta} \\left( \\frac{1}{N_1} \\sum_{i \\in C_1} L(f_{\\theta}(x_i), y_i) + \\frac{1}{N_2} \\sum_{i \\in C_2} L(f_{\\theta}(x_i), y_i) + \\ldots + \\frac{1}{N_K} \\sum_{i \\in C_K} L(f_{\\theta}(x_i), y_i) \\right)","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"Class imbalance occurs when some classes have far fewer examples than other classes. In this case, the terms corresponding to smaller classes contribute minimally to the sum, which makes it possible for a learning algorithm to find an approximate solution that mostly minimizes only the significant sums. This yields a hypothesis f_θ that may be very different from the true target f with respect to the minority classes, which may be the most important ones for the application in question.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"One obvious possible remedy is to weight the smaller sums so that a learning algorithm can no longer profitably ignore them, which can be seen to be equivalent to repeating the observations of the minority classes. 
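To see the equivalence concretely (a standard argument; w_k below denotes the weight given to class k), choosing w_k = N/N_k turns the class-weighted ordinary risk into the balanced objective above:\n\n\\sum_{k=1}^{K} \\frac{1}{N_k} \\sum_{i \\in C_k} L(f_{\\theta}(x_i), y_i) = \\frac{1}{N} \\sum_{i=1}^{N} w_{y_i} L(f_{\\theta}(x_i), y_i)\n\nand duplicating every observation of class k exactly w_k times (when w_k is an integer) rescales the class terms of the plain empirical risk by precisely the same factors. 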
Such duplication can be achieved by naive random oversampling, which is offered by this package along with other more advanced oversampling methods that function by generating synthetic data or deleting existing ones. You can read more about the class imbalance problem and learn about the various algorithms implemented in this package by reading this series of articles on Medium.","category":"page"},{"location":"","page":"Introduction","title":"Introduction","text":"To our knowledge, there are no existing maintained Julia packages that implement resampling algorithms for multi-class classification problems or that handle both nominal and continuous features. This has served as a primary motivation for the creation of this package.","category":"page"}] }