From c13bd43df63d4e02bfadaba826affbcb39a84ba5 Mon Sep 17 00:00:00 2001
From: cszhangzhen
Date: Sat, 1 Jun 2024 03:31:23 +0000
Subject: [PATCH] version up

---
 pygda/nn/reweight_gnn.py | 51 ++++++++++++++++++++--------------------
 pygda/version.py         |  2 +-
 2 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/pygda/nn/reweight_gnn.py b/pygda/nn/reweight_gnn.py
index f576726..b8277f5 100644
--- a/pygda/nn/reweight_gnn.py
+++ b/pygda/nn/reweight_gnn.py
@@ -1,7 +1,6 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from torch_geometric.nn import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.conv.gcn_conv import gcn_norm
 from torch_sparse import SparseTensor
@@ -10,11 +9,9 @@
 from torch.nn import Parameter
 from torch_sparse import SparseTensor
 from typing import Optional
-from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import zeros
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor
 from torch_sparse import matmul as torch_sparse_matmul
-from torch_geometric.nn.conv.gcn_conv import gcn_norm
 
 
 def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
@@ -43,7 +40,7 @@ def spmm(src: Adj, other: Tensor, reduce: str = "sum") -> Tensor:
         raise ValueError(f"`{reduce}` reduction is not supported for "
                          f"`torch.sparse.Tensor`.")
 
-class GCN_reweight(MessagePassing):
+class GCN_reweight(pyg_nn.MessagePassing):
     r"""The graph convolutional operator from the `"Semi-supervised
     Classification with Graph Convolutional Networks"
     <https://arxiv.org/abs/1609.02907>`_ paper
@@ -103,13 +100,21 @@ class GCN_reweight(MessagePassing):
     _cached_edge_index: Optional[OptPairTensor]
     _cached_adj_t: Optional[SparseTensor]
 
-    def __init__(self, in_channels: int, out_channels: int, aggr: str,
-                 improved: bool = False, cached: bool = False,
-                 add_self_loops: bool = False, normalize: bool = True,
-                 bias: bool = True, **kwargs):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        aggr: str,
+        improved: bool = False,
+        cached: bool = False,
+        add_self_loops: bool = False,
+        normalize: bool = True,
+        bias: bool = True,
+        **kwargs
+    ):
 
-        kwargs.setdefault('aggr', "add")
-        super().__init__(**kwargs, flow ="target_to_source")
+        # kwargs.setdefault('aggr', "add")
+        super(GCN_reweight, self).__init__(aggr=aggr, flow="target_to_source")
 
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -120,13 +125,11 @@ def __init__(self, in_channels: int, out_channels: int, aggr: str,
             self.normalize = False
         else:
             self.normalize = True
-        #self.normalize = normalize
-        
+
         self._cached_edge_index = None
         self._cached_adj_t = None
 
-        self.lin = Linear(in_channels, out_channels, bias=False,
-                          weight_initializer='glorot')
+        self.lin = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')
 
         if bias:
             self.bias = Parameter(torch.Tensor(out_channels))
@@ -142,12 +145,9 @@ def reset_parameters(self):
         self._cached_adj_t = None
 
 
-    def forward(self, x: Tensor, edge_index: Adj,
-                edge_weight: OptTensor=None, lmda = 1) -> Tensor:
-        """"""
+    def forward(self, x, edge_index, edge_weight, lmda):
         edge_rw = edge_weight
         edge_weight = torch.ones_like(edge_rw)
-        #edge_weight = None
         if self.normalize:
             if isinstance(edge_index, Tensor):
                 cache = self._cached_edge_index
@@ -174,15 +174,14 @@ def forward(self, x: Tensor, edge_index: Adj,
         x = self.lin(x)
 
         # propagate_type: (x: Tensor, edge_weight: OptTensor)
-        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
-                             size=None, lmda = lmda, edge_rw = edge_rw)
+        out = self.propagate(edge_index, size=None, x=x, edge_weight=edge_weight, edge_rw=edge_rw, lmda=lmda)
 
         if self.bias is not None:
             out = out + self.bias
 
         return out
 
-    def message(self, x_j: Tensor, edge_weight: OptTensor, lmda, edge_rw) -> Tensor:
+    def message(self, x_j, edge_index, edge_weight, edge_rw, lmda):
         x_j = (edge_weight.view(-1, 1) * x_j)
         x_j = (1-lmda) * x_j + (lmda) * (edge_rw.view(-1, 1) * x_j)
         return x_j
@@ -190,9 +189,9 @@ def message(self, x_j: Tensor, edge_weight: OptTensor, lmda, edge_rw) -> Tensor:
     def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
         return spmm(adj_t, x, reduce=self.aggr)
 
+
 class GS_reweight(pyg_nn.MessagePassing):
-    def __init__(self, in_channels, out_channels, reducer,
-                 normalize_embedding=False):
+    def __init__(self, in_channels, out_channels, reducer, normalize_embedding=False):
         super(GS_reweight, self).__init__(aggr=reducer, flow ="target_to_source")
         self.lin = torch.nn.Linear(in_channels, out_channels)
         self.agg_lin = torch.nn.Linear(out_channels + in_channels, out_channels)
@@ -201,12 +200,12 @@ def __init__(self, in_channels, out_channels, reducer,
 
     def forward(self, x, edge_index, edge_weight, lmda):
         num_nodes = x.size(0)
-        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight = edge_weight, lmda = lmda)
+        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_weight=edge_weight, lmda=lmda)
 
     def message(self, x_j, edge_index, edge_weight, lmda):
         x_j = self.lin(x_j)
         x_j = (1-lmda) * x_j + (lmda) * (edge_weight.view(-1, 1) * x_j)
-        #print(lmda)
+
        return x_j
 
     def update(self, aggr_out, x):
@@ -274,7 +273,7 @@ def __init__(
     def forward(self, data, h):
         x, edge_index, edge_weight = h, data.edge_index, data.edge_weight
         for i, layer in enumerate(self.conv):
-            x = layer(x, edge_index, edge_weight=edge_weight, lmda = self.lmda)
+            x = layer(x, edge_index, edge_weight, self.lmda)
             # if self.bn and (i != len(self.conv) - 1):
             #     x = self.bns[i](x)
             x = F.relu(x)
diff --git a/pygda/version.py b/pygda/version.py
index 8dbfdad..cba8e59 100644
--- a/pygda/version.py
+++ b/pygda/version.py
@@ -1 +1 @@
-__version__ = '0.0.3'
\ No newline at end of file
+__version__ = '0.0.4'
\ No newline at end of file
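
Usage note (not part of the applied diff): the hunks above change GCN_reweight.forward
from optional keyword arguments (edge_weight: OptTensor = None, lmda = 1) to four
required positional arguments, and the last reweight_gnn.py hunk updates the layer
loop to match. A minimal sketch of the new calling convention, assuming GCN_reweight
is imported directly from pygda/nn/reweight_gnn.py; the toy graph, feature sizes, and
lmda value below are illustrative only, not taken from the patch:

    import torch
    from pygda.nn.reweight_gnn import GCN_reweight

    # Toy graph: 4 nodes with 16-dim features, 3 directed edges.
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 2, 3]])
    edge_rw = torch.rand(edge_index.size(1))  # per-edge reweighting factors

    # aggr='mean' takes the non-normalizing branch of __init__ above.
    conv = GCN_reweight(in_channels=16, out_channels=32, aggr='mean')

    # edge_weight and lmda are now required positional arguments; message()
    # blends the plain and reweighted messages:
    #   (1 - lmda) * x_j + lmda * (edge_rw * x_j)
    out = conv(x, edge_index, edge_rw, 0.5)
    print(out.shape)  # torch.Size([4, 32])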