Skip to content

Commit

Permalink
finish example for static obstacle
Browse files Browse the repository at this point in the history
  • Loading branch information
rainorangelemon committed Mar 3, 2023
1 parent 9e95a90 commit 48d3b65
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 65 deletions.
2 changes: 1 addition & 1 deletion examples/sipp_planner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('gnnmp')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
148 changes: 99 additions & 49 deletions examples/static_gnn_planner.ipynb

Large diffs are not rendered by default.

24 changes: 13 additions & 11 deletions planner/learned/GNN_static_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from utils.utils import seed_everything, create_dot_dict, to_np
from utils.graphs import knn_graph_from_points

from planner.learned.model.GNN_static import GNNet

from torch_sparse import coalesce
from torch_geometric.nn import knn_graph
from torch_geometric.data import Data
Expand All @@ -12,16 +14,15 @@


class GNNStaticPlanner(LearnedPlanner):
def __init__(self, num_batch, model, k_neighbors=50, **kwargs):
def __init__(self, num_batch, model_args, k_neighbors=50, **kwargs):
self.num_batch = num_batch
self.model = model
self.model = GNNet(**model_args)
self.k_neigbors = k_neighbors
self.num_node = 0

super(GNNStaticPlanner, self).__init__(self.model, **kwargs)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for model_ in self.model:
model_.to(self.device)
self.model.to(self.device)

def _num_node(self):
return self.num_node
Expand All @@ -32,18 +33,19 @@ def _plan(self, env, start, goal, timeout, seed=0, **kwargs):
self.model.eval()
path = self._explore(env, start, goal, self.model, timeout, k=self.k_neigbors, n_sample=self.num_batch)

return create_dot_dict(solution=path)
return create_dot_dict(solution=path if len(path) else None)

def create_graph(self):
graph_data = knn_graph_from_points(self.points, self.k_neighbors)
self.edges = graph_data.edges
self.edge_index = graph_data.edge_index
self.edge_cost = graph_data.edge_cost

def create_data(self, points, edge_index=None, k=50):
def create_data(self, points, obstacles, edge_index=None, k=50):
goal_index = 1
data = Data(goal=torch.FloatTensor(points[goal_index]))
data.v = torch.FloatTensor(points)
data.v = torch.FloatTensor(np.array(points))
data.obstacles = torch.FloatTensor()

if edge_index is not None:
data.edge_index = torch.tensor(edge_index.T).to(self.device)
Expand Down Expand Up @@ -73,7 +75,7 @@ def _explore(self, env, start, goal, model_gnn, timeout, k, n_sample, loop=10):

while not success:

data = self.create_data(points, k)
data = self.create_data(points, env.get_obstacles(), k=k)
self.num_node = len(data.v)
policy = model_gnn(**data.to(self.device).to_dict(), loop=loop)
policy = policy.cpu()
Expand All @@ -89,12 +91,12 @@ def _explore(self, env, start, goal, model_gnn, timeout, k, n_sample, loop=10):
end_a, end_b = int(end_a), int(end_b)
end_a = explored[end_a]
explored_edges.extend([[end_a, end_b], [end_b, end_a]])
if env._edge_fp(to_np(data.v[end_a]), to_np(data.v[end_b])):
if env.edge_fp(to_np(data.v[end_a]), to_np(data.v[end_b])):
explored.append(end_b)
prev[end_b] = end_a

policy[:, end_b] = 0
if env.in_goal_region(to_np(data.v[end_b]), to_np(data.v[1])):
if end_b==1:
success = True
path = [end_b]
node = end_b
Expand All @@ -114,7 +116,7 @@ def _explore(self, env, start, goal, model_gnn, timeout, k, n_sample, loop=10):
new_points = env.sample_n_points(n_sample, need_negative=True)
points = points + list(new_points)

return create_dot_dict(solution=list(data.v[path].data.cpu().numpy()) if len(path) else None)
return list(data.v[path].data.cpu().numpy())



2 changes: 1 addition & 1 deletion planner/learned/model/GNN_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid
from torch_geometric.nn import MessagePassing

from base_models import Block
from planner.learned.model.base_models import Block


class MPNN(MessagePassing):
Expand Down
6 changes: 3 additions & 3 deletions planner/learned/model/GNN_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Sigmoid
from torch_geometric.nn import MessagePassing

from base_models import Block
from planner.learned.model.base_models import Block


class MPNN(MessagePassing):
Expand Down Expand Up @@ -73,8 +73,8 @@ def forward(self, v, labels, obstacles, edge_index, loop, **kwargs):
obs_edge_code = self.obs_edge_code(obstacles.view(-1, self.obs_size))

for na, ea in zip(self.node_attentions, self.edge_attentions):
x, obs_node_code = na(x, obs_node_code)
y, obs_edge_code = ea(y, obs_edge_code)
x = na(x, obs_node_code)
y = ea(y, obs_edge_code)

for i in range(loop):
x = self.mpnn(x, edge_index, y)
Expand Down
14 changes: 14 additions & 0 deletions wrappers/obstacles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

class ObstaclePositionWrapper():
'''
return representation of obstacles as concatenated vector of positions
'''
def __init__(self, baseObject):
self.__class__ = type(baseObject.__class__.__name__,
(self.__class__, baseObject.__class__),
{})
self.__dict__ = baseObject.__dict__

def get_obstacles(self):
return np.array([list(obj.base_position)+list(obj.half_extents) for obj in self.objects])

0 comments on commit 48d3b65

Please sign in to comment.