diff --git a/examples/dynamic_gnn_planner.ipynb b/examples/dynamic_gnn_planner.ipynb index 71f586c..f7931c1 100644 --- a/examples/dynamic_gnn_planner.ipynb +++ b/examples/dynamic_gnn_planner.ipynb @@ -67,7 +67,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "waypoints = np.linspace(0, np.pi/4, 40).tolist()\n", @@ -79,18 +84,41 @@ "env = Simple2ArmEnv(objects=[objs])\n", "num_samples = 1000\n", "max_num_samples = 2000" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": { - "collapsed": false, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading dynamic object\n", + "Loading robot from ../data/robot/simple2arm/2dof.urdf\n", + "Robot loaded with item_id 0\n", + "Loading robot from ../data/robot/simple2arm/2dof.urdf\n", + "Robot loaded with item_id 1\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAevUlEQVR4nO3df4wc5Z3n8fe3q7tn2uP57ZmxYxuMwSEgAgmMsiBWp1zY5IBbhfyRzSa72licJd8f3F1yWWlD7v7InXR/JNJps4nuxC0K2TWrXFiWTQ6EuOSIQ7LaKJiYQAjhhz0BG3v8YwbbM2PPD/d09/f+6KqmZzz29NjdU13Tn5dU6qqnq6efooaPn3nqqafM3RERkeRIxV0BERFZGQW3iEjCKLhFRBJGwS0ikjAKbhGRhFFwi4gkTEOC28zuNrM3zWzEzB5sxHeIiLQqq/c4bjMLgAPAx4GjwC+Bz7n7a3X9IhGRFtWIFvdHgBF3f8vd88BjwH0N+B4RkZaUbsDP3Awcqdo+CvzepT6wYcMG37ZtWwOqInJx7k4+nwfAzC66TxAEpNON+F9F5OIOHTrEu+++u+QvZmy/jWa2G9gNcNVVV7F///64qiItqlAocOzYMQBSqdSC8HZ33J1SqUR/fz/r16+Pq5rSooaHhy/6XiO6SkaBrVXbW8KyBdz9YXcfdvfhgYGBBlRDZHlBEBAEAalUasESlQdBcNHWuEhcGtHi/iWww8yuoRzYnwX+pAHfI3JFzIxUKlVZXxzQ0YV7Bbc0m7oHt7sXzOzfAT8CAuA77v7ben+PSD1UB/fF3ldwS7NpSB+3uz8DPNOIny1ST1FwL8XMcHcFtzQdXSqXlqZQliRScEtLq6VVrXCXZqPglpa1+ILkUhcnFdrSjDTJlLS8pUaUROUizUgtbmlpUVdJtF5NLW5pVgpuaWnL9W0rvKUZqatEpMrikFZoSzNSi1sElrxIWe8pj0XqRcEtLU3DACWJFNzSEO7O/Pw87s7k5CRtbW10dXUpDEXqQMEtdVcqlThy5Aj79u2jWCwyOztLJpPh+uuv5+abb6atra0pA1xdJJIUCm6pK3dnZGSEffv2USqVMDOCIKBUKvHaa69x4sQJ7rjjDjZs2NA04X2pC5IKcWlGGlUidVUqlXj77bdJpVKk02nS6TSZTKayPjExwfPPP8+5c+firmpNmuUfF5FqCm6pq7m5Oc6dO0dbWxu5XK6ytLW1kc1myWQyTE5Osn//fkqlUtzVFUkkdZVIXU1NTZFOp2lra6s8TSZ6BFixWKRQKJDP5zl16hTnzp2jq6sr7iqLJI6CW+qqt7eX7u7uSt929QW/UqlEoVBgfn6eubk5JicnFdwil0HBLXWVzWbp6+tjbm6u8ixHuDC4gyDgzJkzbN26dZmfKCKLKbilrsyM/v5+3n33XbLZbOXRX+5e6SqZn58nlUo1RWu7epKppURPey8Wi0D54cIicVNwS12ZGRs3bgTKFyrn5uaYnp5mdnaW3t5ecrkc6XSaIAgYHByMpY6lUolSqVS5SWh2dpaOjg4Azp8/z+zsLOPj40xNTTE1NUUQBLg71157LTfddJNGmkjsFNxSd0EQsHHjRg4ePMgvfvGLyuiRVCrFtddey5YtW1i/fj3t7e2rXrepqSleeOEFzpw5A0ChUGBubo5cLoeZkc/nyefzC6Z6jbp7Dh8+zE033bTqdRZZTMEtDREEAe3t7QuG/JVKJUZGRhgaGmL79u2XfFBvI7g7hw4d4siRIwvKAGZmZiot6ag7JHrAQlQetdTVXSJxW/b/HDP7jpmNmdmrVWV9ZvasmR0MX3vDcjOzb5nZiJm9Yma3NrLy0tzS6fSSDycoFAqxhd+ZM2cqNwMFQbDkTUJReXRxNVqmpqaYnZ2Npd4i1Wpp8vwtcPeisgeBve6+A9gbbgPcA+wIl93AQ/WppiRRX18f2Ww27mosEHV9pFIpgiCoLNUBvXipbnnrFnhpBst2lbj7P5nZtkXF9wEfDdf3AD8FvhyWP+rl3+7nzazHzDa5+/G61VgS42LPcozT4u6P6teLlcN7sx2KNIPL7eMeqgrjE8BQuL4ZOFK139GwTMEtTSG6qzMK6cUtalj62ZNRcE9MTNDZ2RlH1UUqrvjipLu7ma3470cz2025O4WrrrrqSqshCRIFYRyt8UwmU7kdvzq4owulS80MGN08FAQB+Xx+1essstjlXtY/aWabAMLXsbB8FKi+FW5LWHYBd3/Y3YfdfXhgYOAyqyHNLJPJLNk6PX48vj/AouCuXtrb2xe8RhNiVU+Mlc1mm66/XlrX5Qb3U8DOcH0n8GRV+efD0SW3A5Pq325dQRCQy+UuKC8UCjHUpqy3t5dMJrMgwKOQXrxUh3a0f7P12UtrWrarxMy+R/lC5AYzOwp8Ffga8LiZ7QIOA58Jd38GuBcYAWaA+xtQZ5HLYma0t7eTyWQqI0kWjyiJRGO2q5f29vZVH3suspRaRpV87iJv3bXEvg48cKWVEmmUXC5XmUOleijg7Owsp0+fxt3p6empdPEUi0WKxSKlUolsNqubb6Qp6M5JaSlRF4iZVW6yOXz4MO+8807lLs/Dhw/T39/PddddRy6Xo1AoVAJcwS3NQH/3SUMNDQ1dUDY7OxvbmOh0Ok1HR0flbsnJyUmOHj16wdN4Tp06xUsvvcTY2NiCvu6l+uxFVpuCWxrGzMjlcgtuLU+lUkxPT8c2rM7M6O7uJp1O4+4cPHjwohdL8/k8Bw8e5Pz58wRBQFtbW2UWQZE4qatEGiqabCq6XTwaDx3X6Awzo6uri7Nnz5LP5zl//vwl95+fn8fdGRwcrIxGEYmbglsaKpPJ0NfXV5nTen5+nmKxGOuwumw2y+bNmzl79mxN9Whvb6e7u3sVaiZSGwW3NFRHRwcbNmyoPDQ4usgX57A6M6Otra0yUkRzkEjSKLilobq7uxe0VqMx0c1wI0t7ezs9PT1MT0/HXRWRFVFwS0NFFygjUXDn83na2tpirJlIcim4paHMbMEFvWiCKY2HFrl8Cm5puOqQjnNmwMsRBEFTPI1epJqCWxoqurW8eopUuHDO62YVBIHGbkvTUXBLQ0W3lkc3uVzsYQUiUjvdOSmrovopM2ZGoVDQ8xtFLpOCWxpOrWuR+lJXiawKhbdI/ajFLQ2XTl/YPlCQi1w+tbiloRY/QV1Erpxa3LIqSqXSgouRi+e/FpHaqcUtDePuTE5OcuDAAUZHRxkYGGDjxo309/fr2Y0iV0DBLQ1z9uxZfvzjHzMxMYG7c+rUKUZGRtiyZQs33ngjHR0d6kIRuQwKbmmIUqnEr371K6anpysXJ6O+7mPHjjE2Nsb27dv54Ac/yLp16xTgIiugv1elIYrFIpOTk5VnO0ZL9Agzd+fAgQP87Gc/i+0xZiJJtWxwm9lWM3vOzF4zs9+a2RfC8j4ze9bMDoavvWG5mdm3zGzEzF4xs1sbfRDSfFKpFJ2dnXR0dNDZ2UlXVxednZ10dnayfv161q1bR1tbGxMTExw7dkx3UYqsQC0t7gLw5+5+I3A78ICZ3Qg8COx19x3A3nAb4B5gR7jsBh6qe62l6aVSKbZt21Z5kEJPTw89PT10dXXR1dVFT08PfX19dHR0cPLkybirK5Ioy/Zxu/tx4Hi4ftbMXgc2A/cBHw132wP8FPhyWP6ol5tQz5tZj5ltCn+OtAgz4+qrr8bMOH36dKUPu7plXSwWyWaz9PT0xFRLkWRaUR+3mW0DPgzsA4aqwvgEMBSubwaOVH3saFi2+GftNrP9ZrZ/fHx8pfWWBAiCgKuvvprt27czPT1NoVDgzJkznDx5kpGREcyMzs5O3ve+98X61PdUKlVZgiAgCILKti6aSjOqeVSJma0H/hH4ortPVf9Cu7ub2Yo6Kd39YeBhgOHhYXVwrlFRX/exY8eYnJxc8F4QBFx33XV0dnbGVLtyHdLp9IKANrPKAx+Wul1fJG41/VaaWYZyaH/X3b8fFp+MukDMbBMwFpaPAlurPr4lLBNZYGBggM7OzlhbtdXBvagxAnBBqIs0g1pGlRjwCPC6u/9l1VtPATvD9Z3Ak1Xlnw9Hl9wOTKp/W5YyNTUVeyhmMhlyuRwdHR0XLLlcTg80lqZUS4v7TuDPgN+Y2cth2X8CvgY8bma7gMPAZ8L3ngHuBUaAGeD+elZYkieVStHT03NBV8ni7Th0dHQwPz+/ZH92qVSq9IGLNJNaRpX8M3CxZtFdS+zvwANXWC9ZQ1KpVNM+tzGXy1EoFJacwbBUKukCpTQlXXmRltbW1nbR4I6eRq/glmaj4JaWls1mKRaLCx5iXP1EegW3NCMFt8SmWCxWwjEumUxmQXBXU3BLs9JVF1kVQ0NDF5SdPHmSYrEYQ23eEwTBgsmvqpdoYiyRZqMWt6yK9vb2C8qa4Sk40Z2SF6MWtzQjBbe0tOjW9otRaEszUnBLS1vuYcYKbmlGCm5ZFWa2YN4Pd2+KG1uWu8FGwS3NSMEtq6Kvr4+enh7y+XxlAqdisci5c+dindZ1uT5sBbc0IwW3rIpMJsPAwADz8/NAucUd3ZkYt+WCW+EtzUbBLavCzOjr66NQKAALZ9+Lk0JZkkjBLavCzFi/fv0F47YvNRRvtairRJJGwS2rwswqEzpB89yVGPf3i1wOBbesCjNbcm7rZni6u8JbkkbBLatmqaF3cd/yvhyFujQjBbesmqXuUlQwiqycgltW1eKgjju4m6GfXWSl4h9EKy2jGUaQLNaMdRJZjlrcsirMjCAILjnvdRx1UmtbkkgtbolFFJjR8EARqZ1a3LKqqlu4SWjtJqGO0nqWDW4zawf+CWgL93/C3b9qZtcAjwH9wIvAn7l73szagEeB24BTwB+7+6EG1V8SIppYKp/PA+W5S+IKxXw+zxtvvMHU1BS9vb1s3rx5ybrMz89z8uRJTp8+TSqV4pZbbmH9+vUx1FhkoVpa3OeBj7n7OTPLAP9sZv8X+BLwDXd/zMz+F7ALeCh8PePu15nZZ4GvA3/coPpLArg7b7/9NgcOHGB8fByA3t5eurq6KndT3nbbbWSz2VWpz6FDh9i3bx/uThAEbN68mfe///3kcjkAZmdnOXDgAKdPn2Z6errSB79hwwauv/76VamjyKUsG9xevrXtXLiZCRcHPgb8SVi+B/gvlIP7vnAd4Angf5iZeTPcIiexmJyc5IUXXmBmZqYSgmNjY4yNjVXunDx//jx33nlnw5/x6O6Mj49XWtjFYpF33nmH0dHRSlk05SwsvGlobm6uoXUTqVVNFyfNLDCzl4Ex4Fngd8CEu0dXlo4Cm8P1zcARgPD9ScrdKYt/5m4z229m+6NWmKxNR44coVAokMlkyGazldfqZXR0lImJiVWpz9mzZwmCoPJQ4GhIYKlUqvxDUv1+tM/U1FRT3KIvUtPFSXcvAh8ysx7gB8AHrvSL3f1h4GGA4eFh/d+whgVBsOTDgqsnmkqlUqty+/v58+eZmZkhnU5XQvhiYRy1wM0Md2dmZqbh9ROpxYpGlbj7hJk9B9wB9JhZOmxVbwFGw91Gga3AUTNLA92UL1JKixoaGuLEiRPAeyEYPUghatkGQUBXV1fD61IsFikUCks+wCEK8ItdNG2Ghz6IQG2jSgaA+TC0c8DHKV9wfA74NOWRJTuBJ8OPPBVu/yJ8/yfq325tfX193HDDDRw/fpxSqQSUA7RUKjEzM0NfXx+Dg4OVi4ONFgRBpR61iJ6PGfdDH0QitfwmbgL2mFlAuU/8cXd/2sxeAx4zs/8GvAQ8Eu7/CPB3ZjYCnAY+24B6S4KYGRs3bqS3t5cXX3yRycnJSpeFu3PbbbfR3d29KsMDU6kU69atqzxCbam7J6u7UKIlCAK6u7sbXj+RWtQyquQV4MNLlL8FfGSJ8jngj+pSO1kzovm48/k8586VByllMhkymcxFb4VvhLa2Nvr7+5mcnKz0rS8V3NVL1KWzYcOGVamjyHL0t5+sqv7+/gVdDqlUiunpadrb20mn0w0PcDOjo6ODfD5fCe3otXo4YKlUWrCkUqlV/QdG5FIU3NJw586d45VXXqFUKpHP5xeM1Z6YmODnP/857s4111zDrbfe2vBw7O/vZ3Z2liAICIKAVCpVGeoXdYusX7+eUqlEsVikWCxe9Ak+InFQcEvDBUHAsWPHmJ6eXvL9KKhPnjzZ8JkCzYz+/v7KTTanT5/mzTffZGJionLBMpVK0dXVRVdXF4ODg/T29gKsyqgXkVoouKXhMpkM7e3tnD9//oL3qkN6tUZtZDIZNm3ahLszNTVVuSGnem7u6elppqenMTNuvvlmTQErTUXBLQ1nZqxbt45isXjBzS7VN+Cs9gRO0Rzhl7rNPupKUWhLM1FwS8OZGYODg7S3tzM/P0+pVMLMKjMERgE6ODi46gGZzWbp7Oys9G9X1zm6kCnSbBTc0nCpVIobbrgBd2d2dpZisUgqlSKXyy0I6jhGbWSzWXp6ehTckigKblkVUf9xZ2dnrPWYn5/n1KlTnDlzhqmpqUoXTfWdlNVzqOTzeV544QWy2SxDQ0N0d3ezbt06dZ1IrBTc0jLcnV//+teVoYnuTi6Xq3TjLG5xFwoFTp06xcTERCXIBwcHufvuuxs+/azIpSi4paXMzc1V+tSh/DSco0ePXvIz1aNdZmZmVjTPiUgjKLilpWzatIm33noLuPh0rkuJukY0Q6A0AwW3tIzqESyLQ3upEF/cj139NByROCm4paWkUilSqdSCoF7ubs3qBypoaldpBvotlJYSBEHlZqDqR5Vd6ik41S31xUMYReKg4JaWkk6nyeVylVn/Fs+9HT2hBxaGdtRN0t7eruCW2Cm4paW0t7eTy+VqanFXz08Stbg3bNiwYE4TkTgouKWldHd3c/PNNzM5Ocns7CzT09OVmQLz+XylxW1mldZ1R0cH2WyW7u5uurq61OKW2Cm4paWkUim6u7vp7u5ecJt7FNyRaP7t6q4SkWah4JaWVR3IqVRKd0NKYmhQqohIwii4RUQSpubgNrPAzF4ys6fD7WvMbJ+ZjZjZ35tZNixvC7dHwve3NajuIiItaSUt7i8Ar1dtfx34hrtfB5wBdoXlu4AzYfk3wv1ERKROagpuM9sC/Gvg2+G2AR8Dngh32QN8Kly/L9wmfP8u0yV5EZG6qbXF/VfAXwDRfJb9wIS7F8Lto8DmcH0zcAQgfH8y3H8BM9ttZvvNbP/4+Pjl1V5EpAUtG9xm9ofAmLu/WM8vdveH3X3Y3YcHBgbq+aNFRNa0WsZx3wl80szuBdqBLuCbQI+ZpcNW9RZgNNx/FNgKHDWzNNANnKp7zUVEWtSyLW53/4q7b3H3bcBngZ+4+58CzwGfDnfbCTwZrj8VbhO+/xNfyYz1IiJySVcyjvvLwJfMbIRyH/YjYfkjQH9Y/iXgwSurooiIVFvRLe/u/lPgp+H6W8BHlthnDvijOtRNRESWoDsnRUQSRsEtIpIwCm4RkYRRcIuIJIyCW0QkYRTcIiIJo+AWEUkYBbeISMIouEVEEkbBLSKSMApuEZGEUXCLiCSMgltEJGEU3CIiCaPgFhFJGAW3iEjCKLhFRBJGwS0ikjAKbhGRhFFwi4gkjIJbRCRhagpuMztkZr8xs5fNbH9Y1mdmz5rZwfC1Nyw3M/uWmY2Y2StmdmsjD0BEpNWspMX9L939Q+4+HG4/COx19x3A3nAb4B5gR7jsBh6qV2VFROTKukruA/aE63uAT1WVP+plzwM9ZrbpCr5HRESq1BrcDvw/M3vRzHaHZUPufjxcPwEMheubgSNVnz0alomISB2ka9zv99191MwGgWfN7I3qN93dzcxX8sXhPwC7Aa666qqVfFREpKXV1OJ299HwdQz4AfAR4GTUBRK+joW7jwJbqz6+JSxb/DMfdvdhdx8eGBi4/CMQEWkxywa3mXWYWWe0DnwCeBV4CtgZ7rYTeDJcfwr4fDi65HZgsqpLRURErlAtXSVDwA/MLNr/f7v7D83sl8DjZrYLOAx8Jtz/GeBeYASYAe6ve61FRFrYssHt7m8BtyxRfgq4a4lyBx6oS+1EROQCunNSRCRhFNwiIgmj4BYRSRgFt4hIwii4RUQSRsEtIpIwCm4RkYRRcIuIJIyCW0QkYRTcIiIJo+AWEUkYBbeISMIouEVEEkbBLSKSMApuEZGEUXCLiCSMgltEJGEU3CIiCaPgFhFJGAW3iEjCKLhFRBKmpuA2sx4ze8LM3jCz183sDjPrM7Nnzexg+Nob7mtm9i0zGzGzV8zs1sYegohIa6m1xf1N4Ifu/gHgFuB14EFgr7vvAPaG2wD3ADvCZTfwUF1rLCLS4pYNbjPrBv4F8AiAu+fdfQK4D9gT7rYH+FS4fh/wqJc9D/SY2aY611tEpGXV0uK+BhgH/sbMXjKzb5tZBzDk7sfDfU4AQ+H6ZuBI1eePhmUiIlIHtQR3GrgVeMjdPwxM8163CADu7oCv5IvNbLeZ7Tez/ePj4yv5qIhIS6sluI8CR919X7j9BOUgPxl1gYSvY+H7o8DWqs9vCcsWcPeH3X3Y3YcHBgYut/4iIi1n2eB29xPAETO7Piy6C3gNeArYGZbtBJ4M158CPh+OLrkdmKzqUhERkSuUrnG/fw9818yywFvA/ZRD/3Ez2wUcBj4T7vsMcC8wAsyE+4qISJ3UFNzu/jIwvMRbdy2xrwMPXFm1RETkYnTnpIhIwii4RUQSRsEtIpIwCm4RkYRRcIuIJIyCW0QkYRTcIiIJo+AWEUkYBbeISMIouEVEEkbBLSKSMApuEZGEUXCLiCSMgltEJGEU3CIiCaPgFhFJGAW3iEjCKLhFRBJGwS0ikjAKbhGRhFFwi4gkjIJbRCRhlg1uM7vezF6uWqbM7Itm1mdmz5rZwfC1N9zfzOxbZjZiZq+Y2a2NPwwRkdaxbHC7+5vu/iF3/xBwGzAD/AB4ENjr7juAveE2wD3AjnDZDTzUgHqLiLSslXaV3AX8zt0PA/cBe8LyPcCnwvX7gEe97Hmgx8w21aOyIiKy8uD+LPC9cH3I3Y+H6yeAoXB9M3Ck6jNHw7IFzGy3me03s/3j4+MrrIaISOuqObjNLAt8EviHxe+5uwO+ki9294fdfdjdhwcGBlbyURGRlraSFvc9wK/c/WS4fTLqAglfx8LyUWBr1ee2hGUiIlIHKwnuz/FeNwnAU8DOcH0n8GRV+efD0SW3A5NVXSoiInKF0rXsZGYdwMeBf1tV/DXgcTPbBRwGPhOWPwPcC4xQHoFyf91qKyIitQW3u08D/YvKTlEeZbJ4XwceqEvtRETkAlbO2ZgrYXYWeDPueqyyDcC7cVdiFel4175WO+ZGH+/V7r7kyI2aWtyr4E13H467EqvJzPa30jHreNe+VjvmOI9Xc5WIiCSMgltEJGGaJbgfjrsCMWi1Y9bxrn2tdsyxHW9TXJwUEZHaNUuLW0REahR7cJvZ3Wb2Zjh/94PLf6L5mdlWM3vOzF4zs9+a2RfC8jU9h7mZBWb2kpk9HW5fY2b7wuP6+3C+G8ysLdweCd/fFmvFL5OZ9ZjZE2b2hpm9bmZ3rOVzbGb/Mfx9ftXMvmdm7WvtHJvZd8xszMxerSpb8Tk1s53h/gfNbOdS33UlYg1uMwuA/0l5HpQbgc+Z2Y1x1qlOCsCfu/uNwO3AA+FxrfU5zL8AvF61/XXgG+5+HXAG2BWW7wLOhOXfCPdLom8CP3T3DwC3UD72NXmOzWwz8B+AYXe/CQgozxa61s7x3wJ3Lypb0Tk1sz7gq8DvAR8BvhqFfd24e2wLcAfwo6rtrwBfibNODTrOJylPGfAmsCks20R5/DrAXwOfq9q/sl9SFsqTie0FPgY8DRjlmxPSi8818CPgjnA9He5ncR/DCo+3G3h7cb3X6jnmvema+8Jz9jTwr9biOQa2Aa9e7jmlPK/TX1eVL9ivHkvcXSU1zd2dZOGfiB8G9nGFc5g3ub8C/gIohdv9wIS7F8Lt6mOqHG/4/iSLplRIgGuAceBvwu6hb4dz+qzJc+zuo8B/B94BjlM+Zy+yts9xZKXntOHnOu7gXtPMbD3wj8AX3X2q+j0v/1O8Job0mNkfAmPu/mLcdVlFaeBW4CF3/zAwzXt/QgNr7hz3Un661TXA+4AOLuxSWPOa5ZzGHdxrdu5uM8tQDu3vuvv3w+K1Oof5ncAnzewQ8Bjl7pJvUn5sXTStQvUxVY43fL8bOLWaFa6Do8BRd98Xbj9BOcjX6jn+A+Btdx9393ng+5TP+1o+x5GVntOGn+u4g/uXwI7wynSW8sWOp2Ku0xUzMwMeAV5397+semtNzmHu7l9x9y3uvo3yOfyJu/8p8Bzw6XC3xccb/Xf4dLh/7K2YlXD3E8ARM7s+LLoLeI01eo4pd5Hcbmbrwt/v6HjX7DmustJz+iPgE2bWG/6l8omwrH6a4ELAvcAB4HfAf467PnU6pt+n/OfUK8DL4XIv5T6+vcBB4MdAX7i/UR5d8zvgN5Sv3Md+HJd57B8Fng7XtwMvUJ6b/R+AtrC8PdweCd/fHne9L/NYPwTsD8/z/wF61/I5Bv4r8AbwKvB3QNtaO8eUHxZzHJin/FfVrss5p8C/CY99BLi/3vXUnZMiIgkTd1eJiIiskIJbRCRhFNwiIgmj4BYRSRgFt4hIwii4RUQSRsEtIpIwCm4RkYT5/3HyS0Wqx4vsAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "# visualize environment\n", "%matplotlib inline\n", @@ -98,34 +126,96 @@ "env.load()\n", "plt.imshow(env.render())\n", "plt.show()" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": { - "collapsed": false, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "from planner.learned.GNN_dynamic_planner import GNNDynamicPlanner" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, + "metadata": {}, "outputs": [], "source": [ - "from planner.learned.GNN_dynamic_planner import GNNDynamicPlanner" - ], + "weights_path = ['data/weights/dynamic/2arms/weights_gnn.pt','data/weights/dynamic/2arms/weights_head.pt']" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": { - "collapsed": false, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ruipeng/Desktop/lemp/planner/learned/model/base_models.py:77: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", + " freqs = min_freq ** (2 * (torch.arange(self.embed_size) // 2) / self.embed_size)\n", + "/Users/ruipeng/Desktop/lemp/planner/learned/GNN_dynamic_planner.py:93: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/distiller/project/pytorch/torch/csrc/utils/tensor_new.cpp:210.)\n", + " data.v = torch.FloatTensor(points)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "success {'solution': [array([0., 0.], dtype=float32), array([0.052411 , 0.00481347], dtype=float32), array([0.10482201, 0.00962694], dtype=float32), array([0.15723301, 0.01444041], dtype=float32), array([0.20964402, 0.01925389], dtype=float32), array([0.26205504, 0.02406736], dtype=float32), array([0.31446603, 0.02888083], dtype=float32), array([0.36687702, 0.0336943 ], dtype=float32), array([0.41151342, 0.03779374], dtype=float32), array([0.46414185, 0.03836958], dtype=float32), array([0.51677024, 0.03894541], dtype=float32), array([0.5693987 , 0.03952125], dtype=float32), array([0.62202716, 0.04009709], dtype=float32), array([0.67465556, 0.04067292], dtype=float32), array([0.72728395, 0.04124876], dtype=float32), array([0.7799124 , 0.04182459], dtype=float32), array([0.83254087, 0.04240043], dtype=float32), array([0.8340862 , 0.04241734], dtype=float32), array([0.8865944 , 0.04601858], dtype=float32), array([0.93910265, 0.04961981], dtype=float32), array([0.9916109 , 0.05322105], dtype=float32), array([1.0441191 , 0.05682229], dtype=float32), array([1.0966274 , 0.06042352], dtype=float32), array([1.1491356 , 0.06402476], dtype=float32), array([1.2016438 , 0.06762599], dtype=float32), array([1.254152 , 0.07122723], dtype=float32), array([1.3066603 , 0.07482847], dtype=float32), array([1.3395048 , 0.07708108], dtype=float32), array([1.3918142 , 0.08289561], dtype=float32), array([1.4441236 , 0.08871014], dtype=float32), array([1.496433 , 0.09452467], dtype=float32), array([1.5487425, 0.1003392], dtype=float32), array([1.6010519 , 0.10615373], dtype=float32), array([1.639276 , 0.11040258], dtype=float32), array([1.6908237 , 0.12102918], dtype=float32), array([1.7423713 , 0.13165578], dtype=float32), array([1.793919 , 0.14228237], dtype=float32), array([1.8454666 , 0.15290898], dtype=float32), array([1.8970141 , 0.16353557], dtype=float32), array([1.9485618 , 0.17416216], dtype=float32), array([2.0001094 , 0.18478876], dtype=float32), array([2.023407 , 0.18959157], dtype=float32), array([2.074234, 0.175928], dtype=float32), array([2.125061 , 0.16226444], dtype=float32), array([2.175888 , 0.14860086], dtype=float32), array([2.2267153 , 0.13493729], dtype=float32), array([2.2775424 , 0.12127371], dtype=float32), array([2.278301 , 0.12106975], dtype=float32), array([2.3292727 , 0.10795621], dtype=float32), array([2.3802445 , 0.09484266], dtype=float32), array([2.4312162 , 0.08172911], dtype=float32), array([2.482188 , 0.06861557], dtype=float32), array([2.5331597 , 0.05550203], dtype=float32), array([2.5841315 , 0.04238848], dtype=float32), array([2.6351032 , 0.02927494], dtype=float32), array([2.686075 , 0.0161614], dtype=float32), array([2.7370467 , 0.00304785], dtype=float32), array([2.7488935, 0. ], dtype=float32)], 'running_time': 1.6064062770000005, 'num_collision_check': 1336, 'num_node': 1000}\n" + ] + } + ], + "source": [ + "start, goal = np.array([0.] * 2), np.array([np.pi*7/8, 0])\n", + "# start, goal = np.array([0.] * 2), np.array([0, 0])\n", + "seed_everything(42)\n", + "model_args = dict(config_size=env.robot.config_dim, embed_size=32, obs_size=9)\n", + "planner = GNNDynamicPlanner(num_samples=num_samples, max_num_samples=max_num_samples, model_args=model_args, use_bctc=False, stop_when_success=True)\n", + "# load trained weights\n", + "planner.load_model(weights_path)\n", + "result = planner.plan(env, start, goal, timeout=('time', 1000))\n", + "if result.solution:\n", + " print('success', result)\n", + "else:\n", + " print('fail')" + ] }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ - "weights_path = ['data/weights/dynamic/2arms/weights_gnn.pt','data/weights/dynamic/2arms/weights_head.pt']" + "# Visualization\n", + "from time import sleep\n", + "def visualize_traj(env, traj_agent):\n", + " max_len_traj = max([len(obj.trajectory.waypoints) for obj in env.objects])\n", + " \n", + " max_time_obs = max_len_traj-1\n", + " speed = 1 / max_time_obs\n", + " gifs = []\n", + " max_len = max(len(traj_agent.waypoints), max_len_traj)\n", + "\n", + " for timestep in np.linspace(0, max_len-1, 100):\n", + " env.robot.set_config(traj_agent.get_spec(timestep))\n", + " for obj in env.objects:\n", + " obj.set_config_at_time(timestep)\n", + " p.performCollisionDetection()\n", + " sleep(0.1)\n", + " gifs.append(p.getCameraImage(width=360, height=360, lightDirection=[1, 1, 1], shadow=1,\n", + " renderer=p.ER_BULLET_HARDWARE_OPENGL)[2])\n", + " return gifs" ], "metadata": { "collapsed": false, @@ -139,18 +229,23 @@ "execution_count": null, "outputs": [], "source": [ - "start, goal = np.array([0.] * 2), np.array([np.pi*7/8, 0])\n", - "# start, goal = np.array([0.] * 2), np.array([0, 0])\n", - "seed_everything(42)\n", - "model_args = dict(config_size=env.robot.config_dim, embed_size=32, obs_size=9)\n", - "planner = GNNDynamicPlanner(num_samples=num_samples, max_num_samples=max_num_samples, model_args=model_args, use_bctc=False, stop_when_success=True)\n", - "# load trained weights\n", - "planner.load_model(weights_path)\n", - "result = planner.plan(env, start, goal, timeout=('time', 1000))\n", - "if result.solution:\n", - " print('success', result)\n", - "else:\n", - " print('fail')" + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import pybullet as p\n", + "import pybullet_data\n", + "import numpy as np\n", + "from utils.utils import save_gif\n", + "from IPython.display import HTML\n", + "import base64\n", + "from objects.trajectory import WaypointLinearTrajectory\n", + "\n", + "env.load(GUI=True)\n", + "\n", + "traj_agent = WaypointLinearTrajectory(result.solution)\n", + "gifs = visualize_traj(env, traj_agent)\n", + "save_gif(gifs, f'data/visualization/GNN-Dynamic.gif')\n", + "b64 = base64.b64encode(open(f'data/visualization/sipp.gif', 'rb').read()).decode('ascii')\n", + "display(HTML(f''))" ], "metadata": { "collapsed": false, @@ -162,11 +257,68 @@ { "cell_type": "code", "execution_count": null, + "outputs": [], + "source": [], "metadata": { + "collapsed": false, "pycharm": { "name": "#%%\n" } - }, + } + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading dynamic object\n", + "Loading robot from ../data/robot/simple2arm/2dof.urdf\n", + "Robot loaded with item_id 0\n", + "Loading robot from ../data/robot/simple2arm/2dof.urdf\n", + "Robot loaded with item_id 1\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import pybullet as p\n", + "import pybullet_data\n", + "import numpy as np\n", + "from utils.utils import save_gif\n", + "from IPython.display import HTML\n", + "import base64\n", + "from objects.trajectory import WaypointLinearTrajectory\n", + "\n", + "env.load(GUI=True)\n", + "\n", + "traj_agent = WaypointLinearTrajectory(result.solution)\n", + "gifs = visualize_traj(env, traj_agent)\n", + "save_gif(gifs, f'data/visualization/GNN-Dynamic.gif')\n", + "b64 = base64.b64encode(open(f'data/visualization/sipp.gif', 'rb').read()).decode('ascii')\n", + "display(HTML(f''))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [] }