Skip to content

Commit

Permalink
added wSAA based on RF_LGBM but boosting_type has still be specified …
Browse files Browse the repository at this point in the history
…manually
  • Loading branch information
kaiguender committed Apr 6, 2023
1 parent 7575365 commit 15ff269
Show file tree
Hide file tree
Showing 17 changed files with 1,021 additions and 349 deletions.
5 changes: 2 additions & 3 deletions _proc/01_levelSetKDEx_univariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "fcdc893f",
"metadata": {},
Expand Down Expand Up @@ -1373,9 +1372,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"display_name": "dddex",
"language": "python",
"name": "python3"
"name": "dddex"
}
},
"nbformat": 4,
Expand Down
144 changes: 135 additions & 9 deletions _proc/03_wSAA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,16 @@
"metadata": {
"language": "python"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": []
},
{
Expand Down Expand Up @@ -72,7 +81,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L22){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### RandomForestWSAA\n",
"\n",
Expand Down Expand Up @@ -119,7 +128,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L22){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### RandomForestWSAA\n",
"\n",
Expand Down Expand Up @@ -211,6 +220,83 @@
"# show_doc(RandomForestWSAA.getWeights)"
]
},
{
"cell_type": "markdown",
"id": "2811aafa",
"metadata": {},
"source": [
"## wSAA - Random Forest LightGBM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### RandomForestWSAA_LGBM\n",
"\n",
"> RandomForestWSAA_LGBM (boosting_type:str='gbdt', num_leaves:int=31,\n",
"> max_depth:int=-1, learning_rate:float=0.1,\n",
"> n_estimators:int=100,\n",
"> subsample_for_bin:int=200000,\n",
"> objective:Union[str,Callable,NoneType]=None,\n",
"> class_weight:Union[Dict,str,NoneType]=None,\n",
"> min_split_gain:float=0.0,\n",
"> min_child_weight:float=0.001,\n",
"> min_child_samples:int=20, subsample:float=1.0,\n",
"> subsample_freq:int=0, colsample_bytree:float=1.0,\n",
"> reg_alpha:float=0.0, reg_lambda:float=0.0, random_\n",
"> state:Union[int,numpy.random.mtrand.RandomState,No\n",
"> neType]=None, n_jobs:int=-1,\n",
"> silent:Union[bool,str]='warn',\n",
"> importance_type:str='split', **kwargs)\n",
"\n",
"LightGBM regressor."
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### RandomForestWSAA_LGBM\n",
"\n",
"> RandomForestWSAA_LGBM (boosting_type:str='gbdt', num_leaves:int=31,\n",
"> max_depth:int=-1, learning_rate:float=0.1,\n",
"> n_estimators:int=100,\n",
"> subsample_for_bin:int=200000,\n",
"> objective:Union[str,Callable,NoneType]=None,\n",
"> class_weight:Union[Dict,str,NoneType]=None,\n",
"> min_split_gain:float=0.0,\n",
"> min_child_weight:float=0.001,\n",
"> min_child_samples:int=20, subsample:float=1.0,\n",
"> subsample_freq:int=0, colsample_bytree:float=1.0,\n",
"> reg_alpha:float=0.0, reg_lambda:float=0.0, random_\n",
"> state:Union[int,numpy.random.mtrand.RandomState,No\n",
"> neType]=None, n_jobs:int=-1,\n",
"> silent:Union[bool,str]='warn',\n",
"> importance_type:str='split', **kwargs)\n",
"\n",
"LightGBM regressor."
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| echo: false\n",
"#| output: asis\n",
"show_doc(RandomForestWSAA_LGBM)"
]
},
{
"cell_type": "markdown",
"id": "f471bba6-50ec-49a9-980d-1c10f4361938",
Expand All @@ -221,15 +307,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L203){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### SampleAverageApproximation\n",
"\n",
Expand All @@ -241,7 +327,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/kaiguender/dddex/blob/main/dddex/wSAA.py#L203){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### SampleAverageApproximation\n",
"\n",
Expand All @@ -251,7 +337,7 @@
"by assigning equal probability to each historical observation of said target variable."
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -325,13 +411,53 @@
"source": [
"# Test Code"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5f055de-89ce-4943-8242-8f434ee8a3f1",
"metadata": {
"language": "python"
},
"outputs": [],
"source": [
"# #| hide\n",
"\n",
"# from lightgbm import LGBMRegressor\n",
"# import lightgbm as lgb\n",
"# from dddex.loadData import *\n",
"# import ipdb\n",
"# import inspect\n",
"# from sklearn.base import RegressorMixin\n",
"\n",
"# data, XTrain, yTrain, XTest, yTest = loadDataBakery()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbcd2880-1b9e-4d07-ab84-4ba8a2bef226",
"metadata": {
"language": "python"
},
"outputs": [],
"source": [
"# #| hide\n",
"\n",
"# RF = RandomForestWSAA_LGBM(max_depth = 2,\n",
"# n_estimators = 10,\n",
"# n_jobs = 1,\n",
"# boosting_type = 'rf',\n",
"# subsample_freq = 0,\n",
"# subsample = 1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dddex",
"display_name": "python3",
"language": "python",
"name": "dddex"
"name": "python3"
}
},
"nbformat": 4,
Expand Down
Binary file modified dddex/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/_modidx.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/baseClasses.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/crossValidation.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/levelSetKDEx_multivariate.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/levelSetKDEx_univariate.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/loadData.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/utils.cpython-38.pyc
Binary file not shown.
Binary file modified dddex/__pycache__/wSAA.cpython-38.pyc
Binary file not shown.
6 changes: 6 additions & 0 deletions dddex/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@
'dddex.wSAA.RandomForestWSAA.getWeights': ('wsaa.html#randomforestwsaa.getweights', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA.pointPredict': ('wsaa.html#randomforestwsaa.pointpredict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA.predict': ('wsaa.html#randomforestwsaa.predict', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM': ('wsaa.html#randomforestwsaa_lgbm', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.fit': ('wsaa.html#randomforestwsaa_lgbm.fit', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.getWeights': ('wsaa.html#randomforestwsaa_lgbm.getweights', 'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.pointPredict': ( 'wsaa.html#randomforestwsaa_lgbm.pointpredict',
'dddex/wSAA.py'),
'dddex.wSAA.RandomForestWSAA_LGBM.predict': ('wsaa.html#randomforestwsaa_lgbm.predict', 'dddex/wSAA.py'),
'dddex.wSAA.SampleAverageApproximation': ('wsaa.html#sampleaverageapproximation', 'dddex/wSAA.py'),
'dddex.wSAA.SampleAverageApproximation.__init__': ( 'wsaa.html#sampleaverageapproximation.__init__',
'dddex/wSAA.py'),
Expand Down
93 changes: 92 additions & 1 deletion dddex/wSAA.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import copy

from sklearn.ensemble import RandomForestRegressor
from lightgbm import LGBMRegressor
from sklearn.base import MetaEstimatorMixin
from .baseClasses import BaseWeightsBasedEstimator
from .utils import restructureWeightsDataList

# %% auto 0
__all__ = ['RandomForestWSAA', 'SampleAverageApproximation']
__all__ = ['RandomForestWSAA', 'RandomForestWSAA_LGBM', 'SampleAverageApproximation']

# %% ../nbs/03_wSAA.ipynb 7
class RandomForestWSAA(RandomForestRegressor, BaseWeightsBasedEstimator):
Expand Down Expand Up @@ -109,6 +110,96 @@ def pointPredict(self,


# %% ../nbs/03_wSAA.ipynb 12
class RandomForestWSAA_LGBM(LGBMRegressor, BaseWeightsBasedEstimator):

def fit(self,
X: np.ndarray, # Feature matrix
y: np.ndarray, # Target values
**kwargs):

super().fit(X = X,
y = y,
**kwargs)

self.yTrain = y

self.leafIndicesTrain = self.pointPredict(X, pred_leaf = True)

#---

def getWeights(self,
X: np.ndarray, # Feature matrix for which conditional density estimates are computed.
# Specifies structure of the returned density estimates. One of:
# 'all', 'onlyPositiveWeights', 'summarized', 'cumDistribution', 'cumDistributionSummarized'
outputType: str='onlyPositiveWeights',
# Optional. List with length X.shape[0]. Values are multiplied to the estimated
# density of each sample for scaling purposes.
scalingList: list=None,
) -> list: # List whose elements are the conditional density estimates for the samples specified by `X`.

__doc__ = BaseWeightsBasedEstimator.getWeights.__doc__

#---

leafIndicesDf = self.pointPredict(X, pred_leaf = True)

weightsDataList = list()

for leafIndices in leafIndicesDf:
leafComparisonMatrix = (self.leafIndicesTrain == leafIndices) * 1
nObsInSameLeaf = np.sum(leafComparisonMatrix, axis = 0)

# It can happen that RF decides that the best strategy is to fit no tree at
# all and simply average all results (happens when min_child_sample is too high, for example).
# In this case 'leafComparisonMatrix' mustn't be averaged because there has been only a single tree.
if len(leafComparisonMatrix.shape) == 1:
weights = leafComparisonMatrix / nObsInSameLeaf
else:
weights = np.mean(leafComparisonMatrix / nObsInSameLeaf, axis = 1)

weightsPosIndex = np.where(weights > 0)[0]

weightsDataList.append((weights[weightsPosIndex], weightsPosIndex))

#---

weightsDataList = restructureWeightsDataList(weightsDataList = weightsDataList,
outputType = outputType,
y = self.yTrain,
scalingList = scalingList,
equalWeights = False)

return weightsDataList

#---

def predict(self : BaseWeightsBasedEstimator,
X: np.ndarray, # Feature matrix for which conditional quantiles are computed.
probs: list, # Probabilities for which quantiles are computed.
outputAsDf: bool=True, # Determines output. Either a dataframe with probs as columns or a dict with probs as keys.
# Optional. List with length X.shape[0]. Values are multiplied to the predictions
# of each sample to rescale values.
scalingList: list=None,
):

__doc__ = BaseWeightsBasedEstimator.predict.__doc__

return super(MetaEstimatorMixin, self).predict(X = X,
probs = probs,
scalingList = scalingList)

#---

def pointPredict(self,
X: np.ndarray, # Feature Matrix
**kwargs):
"""Original `predict` method to generate point forecasts"""

return super().predict(X = X,
**kwargs)


# %% ../nbs/03_wSAA.ipynb 14
class SampleAverageApproximation(BaseWeightsBasedEstimator):
"""SAA is a featureless approach that assumes the density of the target variable is given
by assigning equal probability to each historical observation of said target variable."""
Expand Down
Loading

0 comments on commit 15ff269

Please sign in to comment.