
Commit

version up
cszhangzhen committed Jun 1, 2024
1 parent 8067832 commit c13bd43
Showing 2 changed files with 26 additions and 27 deletions.
51 changes: 25 additions & 26 deletions pygda/nn/reweight_gnn.py
@@ -1,7 +1,6 @@
import torch
from torch import nn
import torch.nn.functional as F
-from torch_geometric.nn import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_sparse import SparseTensor
@@ -10,11 +9,9 @@
from torch.nn import Parameter
from torch_sparse import SparseTensor
from typing import Optional
-from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptPairTensor, OptTensor
from torch_sparse import matmul as torch_sparse_matmul
-from torch_geometric.nn.conv.gcn_conv import gcn_norm


def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
@@ -43,7 +40,7 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
        raise ValueError(f"`{reduce}` reduction is not supported for "
                         f"`torch.sparse.Tensor`.")

-class GCN_reweight(MessagePassing):
+class GCN_reweight(pyg_nn.MessagePassing):
r"""The graph convolutional operator from the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper
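For reference, the operator this docstring cites is the standard GCN propagation rule from the paper,

\mathbf{X}' = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, \qquad \mathbf{\hat{A}} = \mathbf{A} + \mathbf{I},

where \mathbf{\hat{D}} is the diagonal degree matrix of \mathbf{\hat{A}}. GCN_reweight keeps this normalization but additionally rescales individual messages before aggregation, as the changed message method below shows.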
@@ -103,13 +100,21 @@ class GCN_reweight(MessagePassing):
    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]

-    def __init__(self, in_channels: int, out_channels: int, aggr: str,
-                 improved: bool = False, cached: bool = False,
-                 add_self_loops: bool = False, normalize: bool = True,
-                 bias: bool = True, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        aggr: str,
+        improved: bool = False,
+        cached: bool = False,
+        add_self_loops: bool = False,
+        normalize: bool = True,
+        bias: bool = True,
+        **kwargs
+    ):

-        kwargs.setdefault('aggr', "add")
-        super().__init__(**kwargs, flow ="target_to_source")
+        # kwargs.setdefault('aggr', "add")
+        super(GCN_reweight, self).__init__(aggr=aggr, flow ="target_to_source")

        self.in_channels = in_channels
        self.out_channels = out_channels
@@ -120,13 +125,11 @@ def __init__(self, in_channels: int, out_channels: int, aggr: str,
            self.normalize = False
        else:
            self.normalize = True
-        #self.normalize = normalize


        self._cached_edge_index = None
        self._cached_adj_t = None

-        self.lin = Linear(in_channels, out_channels, bias=False,
-                          weight_initializer='glorot')
+        self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
@@ -142,12 +145,9 @@ def reset_parameters(self):
        self._cached_adj_t = None


-    def forward(self, x: Tensor, edge_index: Adj,
-                edge_weight: OptTensor=None, lmda = 1) -> Tensor:
-        """"""
+    def forward(self, x, edge_index, edge_weight, lmda):
        edge_rw = edge_weight
        edge_weight = torch.ones_like(edge_rw)
-        #edge_weight = None
        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
@@ -174,25 +174,24 @@ def forward(self, x: Tensor, edge_index: Adj,
        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
-        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
-                             size=None, lmda = lmda, edge_rw = edge_rw)
+        out = self.propagate(edge_index, size=None, x=x, edge_weight=edge_weight, edge_rw=edge_rw, lmda=lmda)

        if self.bias is not None:
            out = out + self.bias

        return out

-    def message(self, x_j: Tensor, edge_weight: OptTensor, lmda, edge_rw) -> Tensor:
+    def message(self, x_j, edge_index, edge_weight, edge_rw, lmda):
        x_j = (edge_weight.view(-1, 1) * x_j)
        x_j = (1-lmda) * x_j + (lmda) * (edge_rw.view(-1, 1) * x_j)
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)
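The heart of this change is the blend in message: with mixing weight lmda, each GCN-normalized message is interpolated between its uniform form and its edge-reweighted form, so lmda=0 recovers plain GCN and lmda=1 is fully reweighted. A minimal standalone sketch of that rule; blend_messages and the toy tensors are illustrative names, not part of pygda:

import torch

def blend_messages(x_j, edge_weight, edge_rw, lmda):
    # GCN message: normalization coefficient times the neighbor feature.
    m = edge_weight.view(-1, 1) * x_j
    # Convex blend between the plain message and its edge-reweighted copy.
    return (1 - lmda) * m + lmda * (edge_rw.view(-1, 1) * m)

x_j = torch.randn(3, 2)                  # features of 3 source nodes
edge_weight = torch.ones(3)              # normalization coefficients (trivial here)
edge_rw = torch.tensor([0.2, 1.0, 1.8])  # hypothetical per-edge reweighting factors
print(blend_messages(x_j, edge_weight, edge_rw, lmda=0.5))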


class GS_reweight(pyg_nn.MessagePassing):
-    def __init__(self, in_channels, out_channels, reducer,
-                 normalize_embedding=False):
+    def __init__(self, in_channels, out_channels, reducer, normalize_embedding=False):
        super(GS_reweight, self).__init__(aggr=reducer, flow ="target_to_source")
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.agg_lin = torch.nn.Linear(out_channels + in_channels, out_channels)
@@ -201,12 +200,12 @@ def __init__(self, in_channels, out_channels, reducer,

    def forward(self, x, edge_index, edge_weight, lmda):
        num_nodes = x.size(0)
-        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight = edge_weight, lmda = lmda)
+        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight=edge_weight, lmda=lmda)

    def message(self, x_j, edge_index, edge_weight, lmda):
        x_j = self.lin(x_j)
        x_j = (1-lmda) * x_j + (lmda) * (edge_weight.view(-1, 1) * x_j)
-        #print(lmda)

        return x_j

    def update(self, aggr_out, x):
@@ -274,7 +273,7 @@ def __init__(
    def forward(self, data, h):
        x, edge_index, edge_weight = h, data.edge_index, data.edge_weight
        for i, layer in enumerate(self.conv):
-            x = layer(x, edge_index, edge_weight=edge_weight, lmda = self.lmda)
+            x = layer(x, edge_index, edge_weight, self.lmda)
            # if self.bn and (i != len(self.conv) - 1):
            #     x = self.bns[i](x)
            x = F.relu(x)
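Taken together, the calling convention after this commit is positional: the wrapper's forward passes edge_weight and self.lmda straight through to each layer. A hedged usage sketch, assuming pygda from this commit is importable and GCN_reweight behaves as in the diff above (the toy graph, shapes, and aggr choice are illustrative):

import torch
from pygda.nn.reweight_gnn import GCN_reweight

x = torch.randn(4, 16)                        # 4 nodes, 16 input features
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])     # a directed 4-cycle
edge_weight = torch.rand(edge_index.size(1))  # per-edge reweighting factors
conv = GCN_reweight(in_channels=16, out_channels=32, aggr='add')
out = conv(x, edge_index, edge_weight, 0.5)   # lmda passed positionally, as in the wrapper
print(out.shape)                              # expected: torch.Size([4, 32])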
2 changes: 1 addition & 1 deletion pygda/version.py
@@ -1 +1 @@
-__version__ = '0.0.3'
+__version__ = '0.0.4'
