Skip to content

Commit

Permalink
add bad states to example to illustrate the state weighting
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanJamesLew committed Apr 13, 2024
1 parent 0cb9ddb commit a8f2f97
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions notebooks/weighted-cost-func.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"import sys\n",
"sys.path.append(\"..\")\n",
"# this is the convenience function\n",
"from autokoopman import auto_koopman"
"from autokoopman import auto_koopman\n",
"import autokoopman as ak"
]
},
{
Expand All @@ -45,7 +46,12 @@
" initial_states=np.random.uniform(low=-2.0, high=2.0, size=(10, 2)),\n",
" tspan=[0.0, 6.0],\n",
" sampling_period=0.1\n",
")"
")\n",
"\n",
"# add garbage states -- we will weight these values to zero\n",
"training_data = ak.TrajectoriesData({\n",
" key: ak.Trajectory(t.times, np.hstack([t.states, np.random.rand(len(t.states), 3)]), t.inputs) for key, t in training_data._trajs.items()\n",
"})"
]
},
{
Expand Down Expand Up @@ -82,7 +88,8 @@
" #w = np.sum(traj.abs().states, axis=1)\n",
" #w = 1/(traj.abs().states+1.0)\n",
" w = np.ones(traj.states.shape)\n",
" w[:, 0] = 0.2\n",
" w[:, -3:] = 0.0\n",
" w[:, 0] = 1.0\n",
" weights.append(w)\n",
"\n",
" # weight garbage trajectory to zero\n",
Expand Down Expand Up @@ -124,7 +131,7 @@
" learning_weights=weights, # weight the eDMD algorithm objectives\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=10, # maximum number of observables to try\n",
" n_obs=20, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand Down Expand Up @@ -156,7 +163,7 @@
" learning_weights=None, # don't use eDMD weighting\n",
" scoring_weights=weights, # pass weights as required for cost_func=\"weighted\"\n",
" opt=\"grid\", # grid search to find best hyperparameters\n",
" n_obs=10, # maximum number of observables to try\n",
" n_obs=20, # maximum number of observables to try\n",
" max_opt_iter=200, # maximum number of optimization iterations\n",
" grid_param_slices=5, # for grid search, number of slices for each parameter\n",
" n_splits=5, # k-folds validation for tuning, helps stabilize the scoring\n",
Expand Down Expand Up @@ -185,17 +192,18 @@
"# get the model from the experiment results\n",
"model = experiment_results['tuned_model']\n",
"model_uw = experiment_results_unweighted['tuned_model']\n",
"tend = 7.0\n",
"\n",
"# simulate using the learned model\n",
"iv = [0.1, 0.5]\n",
"iv = [0.5, 0.5, 0.0, 0.0, 0.0]\n",
"trajectory = model.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
"    tspan=(0.0, tend),\n",
" sampling_period=0.1\n",
")\n",
"trajectory_uw = model_uw.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
"    tspan=(0.0, tend),\n",
" sampling_period=0.1\n",
")"
]
Expand All @@ -209,17 +217,17 @@
"source": [
"# simulate the ground truth for comparison\n",
"true_trajectory = fhn.solve_ivp(\n",
" initial_state=iv,\n",
" tspan=(0.0, 10.0),\n",
" initial_state=iv[:2],\n",
" tspan=(0.0, tend),\n",
" sampling_period=0.1\n",
")\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"\n",
"# plot the results\n",
"plt.plot(*true_trajectory.states.T, linewidth=2, label='Ground Truth')\n",
"plt.plot(*trajectory.states.T, label='Weighted Trajectory Prediction')\n",
"plt.plot(*trajectory_uw.states.T, label='Trajectory Prediction')\n",
"plt.plot(*true_trajectory.states[:, :2].T, linewidth=2, label='Ground Truth')\n",
"plt.plot(*trajectory.states[:, :2].T, label='Weighted Trajectory Prediction')\n",
"plt.plot(*trajectory_uw.states[:, :2].T, label='Trajectory Prediction')\n",
"\n",
"\n",
"plt.xlabel(\"$x_1$\")\n",
Expand Down

0 comments on commit a8f2f97

Please sign in to comment.