1
+
2
+
3
+ <!DOCTYPE html>
4
+ < html class ="writer-html5 " lang ="en " data-content_root ="../../../ ">
5
+ < head >
6
+ < meta charset ="utf-8 " />
7
+ < meta name ="viewport " content ="width=device-width, initial-scale=1.0 " />
8
+ < title > med_bench.estimation.mediation_coefficient_product — med_bench documentation</ title >
9
+ < link rel ="stylesheet " type ="text/css " href ="../../../_static/pygments.css?v=b86133f3 " />
10
+ < link rel ="stylesheet " type ="text/css " href ="../../../_static/css/theme.css?v=e59714d7 " />
11
+
12
+
13
+ < script src ="../../../_static/jquery.js?v=5d32c60e "> </ script >
14
+ < script src ="../../../_static/_sphinx_javascript_frameworks_compat.js?v=2cd50e6c "> </ script >
15
+ < script src ="../../../_static/documentation_options.js?v=5929fcd5 "> </ script >
16
+ < script src ="../../../_static/doctools.js?v=9bcbadda "> </ script >
17
+ < script src ="../../../_static/sphinx_highlight.js?v=dc90522c "> </ script >
18
+ < script src ="../../../_static/js/theme.js "> </ script >
19
+ < link rel ="index " title ="Index " href ="../../../genindex.html " />
20
+ < link rel ="search " title ="Search " href ="../../../search.html " />
21
+ </ head >
22
+
23
+ < body class ="wy-body-for-nav ">
24
+ < div class ="wy-grid-for-nav ">
25
+ < nav data-toggle ="wy-nav-shift " class ="wy-nav-side ">
26
+ < div class ="wy-side-scroll ">
27
+ < div class ="wy-side-nav-search " >
28
+
29
+
30
+
31
+ < a href ="../../../index.html " class ="icon icon-home ">
32
+ med_bench
33
+ </ a >
34
+ < div role ="search ">
35
+ < form id ="rtd-search-form " class ="wy-form " action ="../../../search.html " method ="get ">
36
+ < input type ="text " name ="q " placeholder ="Search docs " aria-label ="Search docs " />
37
+ < input type ="hidden " name ="check_keywords " value ="yes " />
38
+ < input type ="hidden " name ="area " value ="default " />
39
+ </ form >
40
+ </ div >
41
+ </ div > < div class ="wy-menu wy-menu-vertical " data-spy ="affix " role ="navigation " aria-label ="Navigation menu ">
42
+ < p class ="caption " role ="heading "> < span class ="caption-text "> Contents:</ span > </ p >
43
+ < ul >
44
+ < li class ="toctree-l1 "> < a class ="reference internal " href ="../../../modules.html "> Estimation</ a > </ li >
45
+ < li class ="toctree-l1 "> < a class ="reference internal " href ="../../../modules.html#get-simulated-data "> get_simulated_data</ a > </ li >
46
+ </ ul >
47
+
48
+ </ div >
49
+ </ div >
50
+ </ nav >
51
+
52
+ < section data-toggle ="wy-nav-shift " class ="wy-nav-content-wrap "> < nav class ="wy-nav-top " aria-label ="Mobile navigation menu " >
53
+ < i data-toggle ="wy-nav-top " class ="fa fa-bars "> </ i >
54
+ < a href ="../../../index.html "> med_bench</ a >
55
+ </ nav >
56
+
57
+ < div class ="wy-nav-content ">
58
+ < div class ="rst-content ">
59
+ < div role ="navigation " aria-label ="Page navigation ">
60
+ < ul class ="wy-breadcrumbs ">
61
+ < li > < a href ="../../../index.html " class ="icon icon-home " aria-label ="Home "> </ a > </ li >
62
+ < li class ="breadcrumb-item "> < a href ="../../index.html "> Module code</ a > </ li >
63
+ < li class ="breadcrumb-item active "> med_bench.estimation.mediation_coefficient_product</ li >
64
+ < li class ="wy-breadcrumbs-aside ">
65
+ </ li >
66
+ </ ul >
67
+ < hr />
68
+ </ div >
69
+ < div role ="main " class ="document " itemscope ="itemscope " itemtype ="http://schema.org/Article ">
70
+ < div itemprop ="articleBody ">
71
+
72
+ < h1 > Source code for med_bench.estimation.mediation_coefficient_product</ h1 > < div class ="highlight "> < pre >
73
+ < span > </ span > < span class ="kn "> import</ span > < span class ="w "> </ span > < span class ="nn "> numpy</ span > < span class ="w "> </ span > < span class ="k "> as</ span > < span class ="w "> </ span > < span class ="nn "> np</ span >
74
+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> sklearn.linear_model</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> RidgeCV</ span >
75
+
76
+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> med_bench.estimation.base</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> Estimator</ span >
77
+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> med_bench.utils.constants</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> ALPHAS</ span > < span class ="p "> ,</ span > < span class ="n "> CV_FOLDS</ span > < span class ="p "> ,</ span > < span class ="n "> TINY</ span >
78
+ < span class ="kn "> from</ span > < span class ="w "> </ span > < span class ="nn "> med_bench.utils.decorators</ span > < span class ="w "> </ span > < span class ="kn "> import</ span > < span class ="n "> fitted</ span >
79
+
80
+
81
+ < div class ="viewcode-block " id ="CoefficientProduct ">
82
+ < a class ="viewcode-back " href ="../../../modules.html#med_bench.estimation.mediation_coefficient_product.CoefficientProduct "> [docs]</ a >
83
+ < span class ="k "> class</ span > < span class ="w "> </ span > < span class ="nc "> CoefficientProduct</ span > < span class ="p "> (</ span > < span class ="n "> Estimator</ span > < span class ="p "> ):</ span >
84
+ < span class ="w "> </ span > < span class ="sd "> """Coefficient Product estimatation method class"""</ span >
85
+
86
+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> regularize</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="p "> ,</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> ):</ span >
87
+ < span class ="w "> </ span > < span class ="sd "> """Initializes Coefficient product estimatation method</ span >
88
+
89
+ < span class ="sd "> Parameters</ span >
90
+ < span class ="sd "> ----------</ span >
91
+ < span class ="sd "> regularize (bool) : regularization parameter</ span >
92
+ < span class ="sd "> """</ span >
93
+ < span class ="nb "> super</ span > < span class ="p "> ()</ span > < span class ="o "> .</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="o "> **</ span > < span class ="n "> kwargs</ span > < span class ="p "> )</ span >
94
+
95
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _regularize</ span > < span class ="o "> =</ span > < span class ="n "> regularize</ span >
96
+
97
+ < div class ="viewcode-block " id ="CoefficientProduct.fit ">
98
+ < a class ="viewcode-back " href ="../../../modules.html#med_bench.estimation.mediation_coefficient_product.CoefficientProduct.fit "> [docs]</ a >
99
+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> fit</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> m</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> y</ span > < span class ="p "> ):</ span >
100
+ < span class ="w "> </ span > < span class ="sd "> """Fits nuisance parameters to data</ span >
101
+
102
+ < span class ="sd "> Parameters</ span >
103
+ < span class ="sd "> ----------</ span >
104
+ < span class ="sd "> t array-like, shape (n_samples)</ span >
105
+ < span class ="sd "> treatment value for each unit, binary</ span >
106
+
107
+ < span class ="sd "> m array-like, shape (n_samples)</ span >
108
+ < span class ="sd "> mediator value for each unit, here m is necessary binary and uni-</ span >
109
+ < span class ="sd "> dimensional</ span >
110
+
111
+ < span class ="sd "> x array-like, shape (n_samples, n_features_covariates)</ span >
112
+ < span class ="sd "> covariates (potential confounders) values</ span >
113
+
114
+ < span class ="sd "> y array-like, shape (n_samples)</ span >
115
+ < span class ="sd "> outcome value for each unit, continuous</ span >
116
+
117
+ < span class ="sd "> """</ span >
118
+ < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _regularize</ span > < span class ="p "> :</ span >
119
+ < span class ="n "> alphas</ span > < span class ="o "> =</ span > < span class ="n "> ALPHAS</ span >
120
+ < span class ="k "> else</ span > < span class ="p "> :</ span >
121
+ < span class ="n "> alphas</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> TINY</ span > < span class ="p "> ]</ span >
122
+
123
+ < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> m</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> y</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _resize</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> m</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> y</ span > < span class ="p "> )</ span >
124
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_t_m</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> zeros</ span > < span class ="p "> (</ span > < span class ="n "> m</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ])</ span >
125
+ < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> m</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]):</ span >
126
+ < span class ="n "> m_reg</ span > < span class ="o "> =</ span > < span class ="n "> RidgeCV</ span > < span class ="p "> (</ span > < span class ="n "> alphas</ span > < span class ="o "> =</ span > < span class ="n "> alphas</ span > < span class ="p "> ,</ span > < span class ="n "> cv</ span > < span class ="o "> =</ span > < span class ="n "> CV_FOLDS</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> fit</ span > < span class ="p "> (</ span >
127
+ < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> hstack</ span > < span class ="p "> ((</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="o "> .</ span > < span class ="n "> reshape</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> ))),</ span > < span class ="n "> m</ span > < span class ="p "> [:,</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span >
128
+ < span class ="p "> )</ span >
129
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_t_m</ span > < span class ="p "> [</ span > < span class ="n "> i</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> m_reg</ span > < span class ="o "> .</ span > < span class ="n "> coef_</ span > < span class ="p "> [</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span >
130
+ < span class ="n "> y_reg</ span > < span class ="o "> =</ span > < span class ="n "> RidgeCV</ span > < span class ="p "> (</ span > < span class ="n "> alphas</ span > < span class ="o "> =</ span > < span class ="n "> alphas</ span > < span class ="p "> ,</ span > < span class ="n "> cv</ span > < span class ="o "> =</ span > < span class ="n "> CV_FOLDS</ span > < span class ="p "> )</ span > < span class ="o "> .</ span > < span class ="n "> fit</ span > < span class ="p "> (</ span >
131
+ < span class ="n "> np</ span > < span class ="o "> .</ span > < span class ="n "> hstack</ span > < span class ="p "> ((</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="o "> .</ span > < span class ="n "> reshape</ span > < span class ="p "> (</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 1</ span > < span class ="p "> ),</ span > < span class ="n "> m</ span > < span class ="p "> )),</ span > < span class ="n "> y</ span >
132
+ < span class ="p "> )</ span >
133
+
134
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_y</ span > < span class ="o "> =</ span > < span class ="n "> y_reg</ span > < span class ="o "> .</ span > < span class ="n "> coef_</ span >
135
+
136
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _fitted</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
137
+
138
+ < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> verbose</ span > < span class ="p "> :</ span >
139
+ < span class ="nb "> print</ span > < span class ="p "> (</ span > < span class ="s2 "> "Nuisance models fitted"</ span > < span class ="p "> )</ span > </ div >
140
+
141
+
142
+ < span class ="nd "> @fitted</ span >
143
+ < span class ="k "> def</ span > < span class ="w "> </ span > < span class ="nf "> estimate</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> m</ span > < span class ="p "> ,</ span > < span class ="n "> x</ span > < span class ="p "> ,</ span > < span class ="n "> y</ span > < span class ="p "> ):</ span >
144
+ < span class ="w "> </ span > < span class ="sd "> """Estimates causal effect on data"""</ span >
145
+ < span class ="n "> direct_effect_treated</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_y</ span > < span class ="p "> [</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]]</ span >
146
+ < span class ="n "> direct_effect_control</ span > < span class ="o "> =</ span > < span class ="n "> direct_effect_treated</ span >
147
+ < span class ="n "> indirect_effect_treated</ span > < span class ="o "> =</ span > < span class ="nb "> sum</ span > < span class ="p "> (</ span >
148
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_y</ span > < span class ="p "> [</ span > < span class ="n "> x</ span > < span class ="o "> .</ span > < span class ="n "> shape</ span > < span class ="p "> [</ span > < span class ="mi "> 1</ span > < span class ="p "> ]</ span > < span class ="o "> +</ span > < span class ="mi "> 1</ span > < span class ="p "> :]</ span > < span class ="o "> *</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _coef_t_m</ span >
149
+ < span class ="p "> )</ span >
150
+ < span class ="n "> indirect_effect_control</ span > < span class ="o "> =</ span > < span class ="n "> indirect_effect_treated</ span >
151
+
152
+ < span class ="n "> causal_effects</ span > < span class ="o "> =</ span > < span class ="p "> {</ span >
153
+ < span class ="s2 "> "total_effect"</ span > < span class ="p "> :</ span > < span class ="n "> direct_effect_treated</ span > < span class ="o "> +</ span > < span class ="n "> indirect_effect_control</ span > < span class ="p "> ,</ span >
154
+ < span class ="s2 "> "direct_effect_treated"</ span > < span class ="p "> :</ span > < span class ="n "> direct_effect_treated</ span > < span class ="p "> ,</ span >
155
+ < span class ="s2 "> "direct_effect_control"</ span > < span class ="p "> :</ span > < span class ="n "> direct_effect_control</ span > < span class ="p "> ,</ span >
156
+ < span class ="s2 "> "indirect_effect_treated"</ span > < span class ="p "> :</ span > < span class ="n "> indirect_effect_treated</ span > < span class ="p "> ,</ span >
157
+ < span class ="s2 "> "indirect_effect_control"</ span > < span class ="p "> :</ span > < span class ="n "> indirect_effect_control</ span > < span class ="p "> ,</ span >
158
+ < span class ="p "> }</ span >
159
+ < span class ="k "> return</ span > < span class ="n "> causal_effects</ span > </ div >
160
+
161
+ </ pre > </ div >
162
+
163
+ </ div >
164
+ </ div >
165
+ < footer >
166
+
167
+ < hr />
168
+
169
+ < div role ="contentinfo ">
170
+ < p > © Copyright 2025, Judith Abecassis, Houssam Zenati, Bertrand Thirion, Hadrien Mariaccia, Mouad Zbakh, Sami Boumaïza, Julie Josse.</ p >
171
+ </ div >
172
+
173
+ Built with < a href ="https://www.sphinx-doc.org/ "> Sphinx</ a > using a
174
+ < a href ="https://github.com/readthedocs/sphinx_rtd_theme "> theme</ a >
175
+ provided by < a href ="https://readthedocs.org "> Read the Docs</ a > .
176
+
177
+
178
+ </ footer >
179
+ </ div >
180
+ </ div >
181
+ </ section >
182
+ </ div >
183
+ < script >
184
+ jQuery ( function ( ) {
185
+ SphinxRtdTheme . Navigation . enable ( true ) ;
186
+ } ) ;
187
+ </ script >
188
+
189
+ </ body >
190
+ </ html >
0 commit comments