Skip to content

Commit cd3391d

Browse files
committed
reformat jupyter notebooks
1 parent b57904d commit cd3391d

8 files changed

+675
-544
lines changed

examples/notebook_BMTI.ipynb

+24-19
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
],
4545
"source": [
4646
"# Load a 6 dimensional dataset from the dataset folder\n",
47-
"X = np.genfromtxt('datasets/6d_double_well.txt')\n",
48-
"true_log_den = np.genfromtxt('datasets/6d_double_well_logdensities_and_grads.txt')[:, 0]\n",
47+
"X = np.genfromtxt(\"datasets/6d_double_well.txt\")\n",
48+
"true_log_den = np.genfromtxt(\"datasets/6d_double_well_logdensities_and_grads.txt\")[:, 0]\n",
4949
"\n",
5050
"# Subsample the dataset for a faster run\n",
5151
"every = 1\n",
@@ -96,13 +96,13 @@
9696
"source": [
9797
"d = DensityAdvanced(X, maxk=1000, verbose=True)\n",
9898
"\n",
99-
"# copute the density using the kNN method \n",
99+
"# compute the density using the kNN method\n",
100100
"d.compute_density_kNN(k=10)\n",
101101
"log_den_kNN = d.log_den\n",
102102
"\n",
103103
"# Compute the density using the kstarNN method\n",
104104
"d.compute_density_kstarNN()\n",
105-
"log_den_kstarNN = d.log_den \n",
105+
"log_den_kstarNN = d.log_den\n",
106106
"\n",
107107
"# Compute the density using the BMTI method\n",
108108
"d.compute_density_BMTI()\n",
@@ -115,11 +115,11 @@
115115
"metadata": {},
116116
"outputs": [],
117117
"source": [
118-
"# remove the mean to both the true and estimated density \n",
118+
"# subtract the mean from both the true and estimated densities\n",
119119
"true_log_den = true_log_den - np.mean(true_log_den)\n",
120120
"log_den_kNN = log_den_kNN - np.mean(log_den_kNN)\n",
121121
"log_den_kstarNN = log_den_kstarNN - np.mean(log_den_kstarNN)\n",
122-
"log_den_BMTI = log_den_BMTI - np.mean(log_den_BMTI)\n"
122+
"log_den_BMTI = log_den_BMTI - np.mean(log_den_BMTI)"
123123
]
124124
},
125125
{
@@ -139,13 +139,13 @@
139139
],
140140
"source": [
141141
"# compute MSE errors\n",
142-
"MSE_kNN = np.mean((log_den_kNN - true_log_den)**2)\n",
143-
"MSE_kstarNN = np.mean((log_den_kstarNN - true_log_den)**2)\n",
144-
"MSE_BMTI = np.mean((log_den_BMTI - true_log_den)**2)\n",
142+
"MSE_kNN = np.mean((log_den_kNN - true_log_den) ** 2)\n",
143+
"MSE_kstarNN = np.mean((log_den_kstarNN - true_log_den) ** 2)\n",
144+
"MSE_BMTI = np.mean((log_den_BMTI - true_log_den) ** 2)\n",
145145
"\n",
146-
"print('MSE kNN: ', MSE_kNN)\n",
147-
"print('MSE kstarNN: ', MSE_kstarNN)\n",
148-
"print('MSE BMTI: ', MSE_BMTI)\n"
146+
"print(\"MSE kNN: \", MSE_kNN)\n",
147+
"print(\"MSE kstarNN: \", MSE_kstarNN)\n",
148+
"print(\"MSE BMTI: \", MSE_BMTI)"
149149
]
150150
},
151151
{
@@ -167,14 +167,19 @@
167167
"source": [
168168
"# plot real density vs estimated density\n",
169169
"plt.figure(figsize=(5, 5))\n",
170-
"plt.scatter(true_log_den, log_den_kNN, marker= '.', label='kNN')\n",
171-
"plt.scatter(true_log_den, log_den_kstarNN,marker= '.', label='kstarNN')\n",
172-
"plt.scatter(true_log_den, log_den_BMTI, marker= '.', label='BMTI',)\n",
173-
"plt.plot(true_log_den, true_log_den, 'k--')\n",
174-
"plt.xlabel('True log density')\n",
175-
"plt.ylabel('Estimated log density')\n",
170+
"plt.scatter(true_log_den, log_den_kNN, marker=\".\", label=\"kNN\")\n",
171+
"plt.scatter(true_log_den, log_den_kstarNN, marker=\".\", label=\"kstarNN\")\n",
172+
"plt.scatter(\n",
173+
" true_log_den,\n",
174+
" log_den_BMTI,\n",
175+
" marker=\".\",\n",
176+
" label=\"BMTI\",\n",
177+
")\n",
178+
"plt.plot(true_log_den, true_log_den, \"k--\")\n",
179+
"plt.xlabel(\"True log density\")\n",
180+
"plt.ylabel(\"Estimated log density\")\n",
176181
"plt.legend()\n",
177-
"plt.tight_layout()\n"
182+
"plt.tight_layout()"
178183
]
179184
},
180185
{

examples/notebook_beta_hairpin.ipynb

+64-40
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,15 @@
3232
"import numpy as np\n",
3333
"import matplotlib.pyplot as plt\n",
3434
"import seaborn as sns\n",
35+
"\n",
3536
"sns.set_style(\"ticks\")\n",
3637
"sns.set_context(\"notebook\")\n",
3738
"\n",
3839
"from urllib.request import urlretrieve\n",
39-
"data_url_dihedrals =\"https://figshare.com/ndownloader/files/36359700\"\n",
40+
"\n",
41+
"data_url_dihedrals = \"https://figshare.com/ndownloader/files/36359700\"\n",
4042
"path_dihedrals = \"./cln025traj_dihedrals_decimated_equilibrated.npy\"\n",
41-
"data_url_distances =\"https://figshare.com/ndownloader/files/36359697\"\n",
43+
"data_url_distances = \"https://figshare.com/ndownloader/files/36359697\"\n",
4244
"path_distances = \"./cln025traj_distances_decimated_equilibrated.npy\""
4345
]
4446
},
@@ -79,7 +81,7 @@
7981
}
8082
],
8183
"source": [
82-
"#download dihedral representation data from Figshare\n",
84+
"# download dihedral representation data from Figshare\n",
8385
"urlretrieve(data_url_dihedrals, path_dihedrals)"
8486
]
8587
},
@@ -92,9 +94,9 @@
9294
"source": [
9395
"all_dihedrals = np.load(path_dihedrals)\n",
9496
"# dihedral names in order:\n",
95-
"# phi1 phi2 phi3 phi4 phi5 phi6 phi7 phi8 phi9 \n",
96-
"# psi1 psi2 psi3 psi4 psi5 psi6 psi7 psi8 psi9 \n",
97-
"# chi1_1 chi1_2 chi1_3 chi1_5 chi1_6 chi1_8 chi1_9 chi1_10 \n",
97+
"# phi1 phi2 phi3 phi4 phi5 phi6 phi7 phi8 phi9\n",
98+
"# psi1 psi2 psi3 psi4 psi5 psi6 psi7 psi8 psi9\n",
99+
"# chi1_1 chi1_2 chi1_3 chi1_5 chi1_6 chi1_8 chi1_9 chi1_10\n",
98100
"# chi2_1 chi2_2 chi2_3 chi2_5 chi2_9 chi2_10"
99101
]
100102
},
@@ -113,10 +115,26 @@
113115
}
114116
],
115117
"source": [
116-
"# we then select a subset of 15 dihedrals identified as the most informative \n",
118+
"# we then select a subset of 15 dihedrals identified as the most informative\n",
117119
"# using the information imbalance greedy optimisation of (Glielmo et al., PNAS Nexus, 2022)\n",
118120
"# the final dataset is described by only 15 features\n",
119-
"coords_from_information_imbalance = [1, 4, 5, 7, 10, 12, 13, 14, 15, 16, 17, 18, 19, 24, 25]\n",
121+
"coords_from_information_imbalance = [\n",
122+
" 1,\n",
123+
" 4,\n",
124+
" 5,\n",
125+
" 7,\n",
126+
" 10,\n",
127+
" 12,\n",
128+
" 13,\n",
129+
" 14,\n",
130+
" 15,\n",
131+
" 16,\n",
132+
" 17,\n",
133+
" 18,\n",
134+
" 19,\n",
135+
" 24,\n",
136+
" 25,\n",
137+
"]\n",
120138
"selected_dihedrals = all_dihedrals[:, coords_from_information_imbalance]\n",
121139
"\n",
122140
"print(selected_dihedrals.shape)"
@@ -151,9 +169,11 @@
151169
],
152170
"source": [
153171
"# initialise a Data object\n",
154-
"d_dihedrals = Data(selected_dihedrals+np.pi, verbose=False)\n",
172+
"d_dihedrals = Data(selected_dihedrals + np.pi, verbose=False)\n",
155173
"# compute distances by setting the correct period\n",
156-
"d_dihedrals.compute_distances(maxk=min(selected_dihedrals.shape[0]-1, 10000), period=2.*np.pi)\n",
174+
"d_dihedrals.compute_distances(\n",
175+
" maxk=min(selected_dihedrals.shape[0] - 1, 10000), period=2.0 * np.pi\n",
176+
")\n",
157177
"# estimate the intrinsic dimension\n",
158178
"d_dihedrals.compute_id_2NN()"
159179
]
@@ -169,7 +189,9 @@
169189
"source": [
170190
"# ID scaling analysis using two different methods\n",
171191
"ids_2nn, errs_2nn, scales_2nn = d_dihedrals.return_id_scaling_2NN()\n",
172-
"ids_gride, errs_gride, scales_gride = d_dihedrals.return_id_scaling_gride(range_max=1024)"
192+
"ids_gride, errs_gride, scales_gride = d_dihedrals.return_id_scaling_gride(\n",
193+
" range_max=1024\n",
194+
")"
173195
]
174196
},
175197
{
@@ -192,18 +214,18 @@
192214
}
193215
],
194216
"source": [
195-
"col = 'darkorange'\n",
217+
"col = \"darkorange\"\n",
196218
"plt.plot(scales_2nn, ids_2nn, alpha=0.85)\n",
197-
"plt.errorbar(scales_2nn, ids_2nn, errs_2nn, fmt='None')\n",
198-
"plt.scatter(scales_2nn, ids_2nn, edgecolors='k',s=50,label='2nn decimation')\n",
219+
"plt.errorbar(scales_2nn, ids_2nn, errs_2nn, fmt=\"None\")\n",
220+
"plt.scatter(scales_2nn, ids_2nn, edgecolors=\"k\", s=50, label=\"2nn decimation\")\n",
199221
"plt.plot(scales_gride, ids_gride, alpha=0.85, color=col)\n",
200-
"plt.errorbar(scales_gride, ids_gride, errs_gride, fmt='None',color=col)\n",
201-
"plt.scatter(scales_gride, ids_gride, edgecolors='k',color=col,s=50,label='2nn gride')\n",
202-
"plt.xlabel(r'Scale',size=15)\n",
203-
"plt.ylabel('Estimated ID',size=15)\n",
222+
"plt.errorbar(scales_gride, ids_gride, errs_gride, fmt=\"None\", color=col)\n",
223+
"plt.scatter(scales_gride, ids_gride, edgecolors=\"k\", color=col, s=50, label=\"2nn gride\")\n",
224+
"plt.xlabel(r\"Scale\", size=15)\n",
225+
"plt.ylabel(\"Estimated ID\", size=15)\n",
204226
"plt.xticks(size=15)\n",
205227
"plt.yticks(size=15)\n",
206-
"plt.legend(frameon=False,fontsize=14)\n",
228+
"plt.legend(frameon=False, fontsize=14)\n",
207229
"plt.tight_layout()"
208230
]
209231
},
@@ -227,7 +249,7 @@
227249
],
228250
"source": [
229251
"# estimate density via PAk\n",
230-
"d_dihedrals.set_id(7.)\n",
252+
"d_dihedrals.set_id(7.0)\n",
231253
"d_dihedrals.compute_density_PAk()"
232254
]
233255
},
@@ -250,7 +272,7 @@
250272
],
251273
"source": [
252274
"# cluster data via Advanced Density Peak\n",
253-
"d_dihedrals.compute_clustering_ADP(Z=4.5,halo=False);\n",
275+
"d_dihedrals.compute_clustering_ADP(Z=4.5, halo=False)\n",
254276
"n_clusters = len(d_dihedrals.cluster_centers)\n",
255277
"print(n_clusters)"
256278
]
@@ -275,7 +297,7 @@
275297
}
276298
],
277299
"source": [
278-
"pl.get_dendrogram(d_dihedrals, cmap='Set2', logscale=False)"
300+
"pl.get_dendrogram(d_dihedrals, cmap=\"Set2\", logscale=False)"
279301
]
280302
},
281303
{
@@ -299,7 +321,7 @@
299321
],
300322
"source": [
301323
"# Cluster populations\n",
302-
"populations = [ len(el) for r_,el in enumerate(d_dihedrals.cluster_indices)]\n",
324+
"populations = [len(el) for r_, el in enumerate(d_dihedrals.cluster_indices)]\n",
303325
"populations"
304326
]
305327
},
@@ -420,8 +442,8 @@
420442
}
421443
],
422444
"source": [
423-
"d_distances = Data(heavy_atom_distances,verbose=False)\n",
424-
"d_distances.compute_distances(maxk=min(heavy_atom_distances.shape[0]-1,10000))\n",
445+
"d_distances = Data(heavy_atom_distances, verbose=False)\n",
446+
"d_distances.compute_distances(maxk=min(heavy_atom_distances.shape[0] - 1, 10000))\n",
425447
"d_distances.compute_id_2NN()"
426448
]
427449
},
@@ -434,7 +456,9 @@
434456
"source": [
435457
"# ID scaling analysis using two different methods\n",
436458
"ids_2nn, errs_2nn, scales_2nn = d_distances.return_id_scaling_2NN()\n",
437-
"ids_gride, errs_gride, scales_gride = d_distances.return_id_scaling_gride(range_max=1024)"
459+
"ids_gride, errs_gride, scales_gride = d_distances.return_id_scaling_gride(\n",
460+
" range_max=1024\n",
461+
")"
438462
]
439463
},
440464
{
@@ -457,18 +481,18 @@
457481
}
458482
],
459483
"source": [
460-
"col = 'darkorange'\n",
484+
"col = \"darkorange\"\n",
461485
"plt.plot(scales_2nn, ids_2nn, alpha=0.85)\n",
462-
"plt.errorbar(scales_2nn, ids_2nn, errs_2nn, fmt='None')\n",
463-
"plt.scatter(scales_2nn, ids_2nn, edgecolors='k',s=50,label='2nn decimation')\n",
486+
"plt.errorbar(scales_2nn, ids_2nn, errs_2nn, fmt=\"None\")\n",
487+
"plt.scatter(scales_2nn, ids_2nn, edgecolors=\"k\", s=50, label=\"2nn decimation\")\n",
464488
"plt.plot(scales_gride, ids_gride, alpha=0.85, color=col)\n",
465-
"plt.errorbar(scales_gride, ids_gride, errs_gride, fmt='None',color=col)\n",
466-
"plt.scatter(scales_gride, ids_gride, edgecolors='k',color=col,s=50,label='2nn gride')\n",
467-
"plt.xlabel(r'Scale',size=15)\n",
468-
"plt.ylabel('Estimated ID',size=15)\n",
489+
"plt.errorbar(scales_gride, ids_gride, errs_gride, fmt=\"None\", color=col)\n",
490+
"plt.scatter(scales_gride, ids_gride, edgecolors=\"k\", color=col, s=50, label=\"2nn gride\")\n",
491+
"plt.xlabel(r\"Scale\", size=15)\n",
492+
"plt.ylabel(\"Estimated ID\", size=15)\n",
469493
"plt.xticks(size=15)\n",
470494
"plt.yticks(size=15)\n",
471-
"plt.legend(frameon=False,fontsize=14)\n",
495+
"plt.legend(frameon=False, fontsize=14)\n",
472496
"plt.tight_layout()"
473497
]
474498
},
@@ -489,10 +513,10 @@
489513
],
490514
"source": [
491515
"# estimate density via PAk\n",
492-
"d_distances.set_id(9.)\n",
516+
"d_distances.set_id(9.0)\n",
493517
"d_distances.compute_density_PAk()\n",
494518
"# cluster data via Advanced Density Peak\n",
495-
"d_distances.compute_clustering_ADP(Z=3.5,halo=False);\n",
519+
"d_distances.compute_clustering_ADP(Z=3.5, halo=False)\n",
496520
"n_clusters = len(d_dihedrals.cluster_centers)\n",
497521
"print(n_clusters)"
498522
]
@@ -519,7 +543,7 @@
519543
}
520544
],
521545
"source": [
522-
"pl.get_dendrogram(d_distances, cmap='Set2', logscale=False)"
546+
"pl.get_dendrogram(d_distances, cmap=\"Set2\", logscale=False)"
523547
]
524548
},
525549
{
@@ -541,7 +565,7 @@
541565
],
542566
"source": [
543567
"# Cluster populations\n",
544-
"populations = [ len(el) for r_,el in enumerate(d_distances.cluster_indices)]\n",
568+
"populations = [len(el) for r_, el in enumerate(d_distances.cluster_indices)]\n",
545569
"populations"
546570
]
547571
},
@@ -606,7 +630,7 @@
606630
],
607631
"source": [
608632
"# number of elements in common before permutation\n",
609-
"sum(d_distances.cluster_assignment == d_dihedrals.cluster_assignment)/d_dihedrals.N"
633+
"sum(d_distances.cluster_assignment == d_dihedrals.cluster_assignment) / d_dihedrals.N"
610634
]
611635
},
612636
{
@@ -645,7 +669,7 @@
645669
],
646670
"source": [
647671
"# number of elements in common after permutation\n",
648-
"sum(distances_cluster_assignments_2 == d_dihedrals.cluster_assignment)/d_dihedrals.N"
672+
"sum(distances_cluster_assignments_2 == d_dihedrals.cluster_assignment) / d_dihedrals.N"
649673
]
650674
},
651675
{

0 commit comments

Comments
 (0)