-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b2c1678
commit ecc6bc2
Showing
12 changed files
with
1,603 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,39 @@ | ||
# GraphATA | ||
# GraphATA | ||
Aggregate to Adapt: Node-Centric Aggregation for Multi-Source-Free Graph Domain Adaptation (WWW-2025). | ||
|
||
 | ||
|
||
This is a PyTorch implementation of the GraphATA algorithm, which tries to address the multi-source domain adaptation problem without accessing the labelled source graph. Unlike previous multi-source domain adaptation approaches that aggregate predictions at model level, we introduce a novel model named GraphATA which conducts adaptation at node granularity. Specifically, we parameterize each node with its own graph convolutional matrix by automatically aggregating weight matrices from multiple source models according to its local context, thus realizing dynamic adaptation over graph structured data. We also demonstrate the capability of GraphATA to generalize to both model-centric and layer-centric methods. | ||
|
||
## Requirements | ||
* python3.8 | ||
* pytorch==1.13.1 | ||
* torch-scatter==2.1.0 | ||
* torch-sparse==0.6.15 | ||
* torch-cluster==1.6.0 | ||
* torch-geometric==2.4.0 | ||
* numpy==1.23.4 | ||
* scipy==1.9.3 | ||
|
||
## Datasets | ||
Datasets used in the paper are all publicly available datasets. | ||
|
||
## Quick Start For Node Classification: | ||
Just execuate the following command for source model pre-training: | ||
``` | ||
python train_source_node.py | ||
``` | ||
Then, execuate the following command for adaptation: | ||
``` | ||
python train_target_node.py | ||
``` | ||
|
||
## Quick Start For Graph Classification: | ||
Just execuate the following command for source model pre-training: | ||
``` | ||
python train_source_graph.py | ||
``` | ||
Then, execuate the following command for adaptation: | ||
``` | ||
python train_target_graph.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
import os.path as osp | ||
import torch | ||
import numpy as np | ||
from torch_geometric.data import InMemoryDataset, Data | ||
from torch_geometric.io import read_txt_array | ||
import torch.nn.functional as F | ||
import random | ||
|
||
import scipy | ||
import pickle as pkl | ||
from sklearn.preprocessing import label_binarize | ||
import csv | ||
import json | ||
|
||
import warnings | ||
warnings.filterwarnings('ignore', category=DeprecationWarning) | ||
|
||
|
||
class CitationDataset(InMemoryDataset): | ||
def __init__(self, | ||
root, | ||
name, | ||
transform=None, | ||
pre_transform=None, | ||
pre_filter=None): | ||
self.name = name | ||
self.root = root | ||
super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter) | ||
|
||
self.data, self.slices = torch.load(self.processed_paths[0]) | ||
|
||
@property | ||
def raw_file_names(self): | ||
return ["docs.txt", "edgelist.txt", "labels.txt"] | ||
|
||
@property | ||
def processed_file_names(self): | ||
return ['data.pt'] | ||
|
||
def download(self): | ||
pass | ||
|
||
def process(self): | ||
edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name)) | ||
edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t() | ||
|
||
docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name)) | ||
f = open(docs_path, 'rb') | ||
content_list = [] | ||
for line in f.readlines(): | ||
line = str(line, encoding="utf-8") | ||
content_list.append(line.split(",")) | ||
x = np.array(content_list, dtype=float) | ||
x = torch.from_numpy(x).to(torch.float) | ||
|
||
label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name)) | ||
f = open(label_path, 'rb') | ||
content_list = [] | ||
for line in f.readlines(): | ||
line = str(line, encoding="utf-8") | ||
line = line.replace("\r", "").replace("\n", "") | ||
content_list.append(line) | ||
y = np.array(content_list, dtype=int) | ||
y = torch.from_numpy(y).to(torch.int64) | ||
|
||
data_list = [] | ||
data = Data(edge_index=edge_index, x=x, y=y) | ||
|
||
random_node_indices = np.random.permutation(y.shape[0]) | ||
training_size = int(len(random_node_indices) * 0.8) | ||
val_size = int(len(random_node_indices) * 0.1) | ||
train_node_indices = random_node_indices[:training_size] | ||
val_node_indices = random_node_indices[training_size:training_size + val_size] | ||
test_node_indices = random_node_indices[training_size + val_size:] | ||
|
||
train_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
train_masks[train_node_indices] = 1 | ||
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
val_masks[val_node_indices] = 1 | ||
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
test_masks[test_node_indices] = 1 | ||
|
||
data.train_mask = train_masks | ||
data.val_mask = val_masks | ||
data.test_mask = test_masks | ||
|
||
if self.pre_transform is not None: | ||
data = self.pre_transform(data) | ||
|
||
data_list.append(data) | ||
|
||
data, slices = self.collate([data]) | ||
|
||
torch.save((data, slices), self.processed_paths[0]) | ||
|
||
|
||
class TwitchDataset(InMemoryDataset): | ||
def __init__(self, | ||
root, | ||
name, | ||
transform=None, | ||
pre_transform=None, | ||
pre_filter=None): | ||
self.name = name | ||
self.root = root | ||
super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter) | ||
|
||
self.data, self.slices = torch.load(self.processed_paths[0]) | ||
|
||
@property | ||
def raw_file_names(self): | ||
return ["edges.csv, features.json, target.csv"] | ||
|
||
@property | ||
def processed_file_names(self): | ||
return ['data.pt'] | ||
|
||
def download(self): | ||
pass | ||
|
||
def load_twitch(self, lang): | ||
assert lang in ('DE', 'EN', 'ES', 'FR', 'PTBR', 'RU'), 'Invalid dataset' | ||
filepath = self.raw_dir | ||
label = [] | ||
node_ids = [] | ||
src = [] | ||
targ = [] | ||
uniq_ids = set() | ||
print(filepath) | ||
with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f: | ||
reader = csv.reader(f) | ||
next(reader) | ||
for row in reader: | ||
node_id = int(row[5]) | ||
# handle FR case of non-unique rows | ||
if node_id not in uniq_ids: | ||
uniq_ids.add(node_id) | ||
label.append(int(row[2]=="True")) | ||
node_ids.append(int(row[5])) | ||
|
||
node_ids = np.array(node_ids, dtype=np.int) | ||
with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f: | ||
reader = csv.reader(f) | ||
next(reader) | ||
for row in reader: | ||
src.append(int(row[0])) | ||
targ.append(int(row[1])) | ||
with open(f"{filepath}/musae_{lang}_features.json", 'r') as f: | ||
j = json.load(f) | ||
src = np.array(src) | ||
targ = np.array(targ) | ||
label = np.array(label) | ||
inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)} | ||
reorder_node_ids = np.zeros_like(node_ids) | ||
for i in range(label.shape[0]): | ||
reorder_node_ids[i] = inv_node_ids[i] | ||
|
||
n = label.shape[0] | ||
A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n,n)) | ||
features = np.zeros((n,3170)) | ||
for node, feats in j.items(): | ||
if int(node) >= n: | ||
continue | ||
features[int(node), np.array(feats, dtype=int)] = 1 | ||
# features = features[:, np.sum(features, axis=0) != 0] # remove zero cols. not need for cross graph task | ||
new_label = label[reorder_node_ids] | ||
label = new_label | ||
|
||
return A, label, features | ||
|
||
def process(self): | ||
A, label, features = self.load_twitch(self.name) | ||
A = A.todense() + A.todense().T | ||
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long) | ||
features = np.array(features) | ||
x = torch.from_numpy(features).to(torch.float) | ||
y = torch.from_numpy(label).to(torch.int64) | ||
|
||
data_list = [] | ||
data = Data(edge_index=edge_index, x=x, y=y) | ||
|
||
random_node_indices = np.random.permutation(y.shape[0]) | ||
training_size = int(len(random_node_indices) * 0.8) | ||
val_size = int(len(random_node_indices) * 0.1) | ||
train_node_indices = random_node_indices[:training_size] | ||
val_node_indices = random_node_indices[training_size:training_size + val_size] | ||
test_node_indices = random_node_indices[training_size + val_size:] | ||
|
||
train_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
train_masks[train_node_indices] = 1 | ||
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
val_masks[val_node_indices] = 1 | ||
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool) | ||
test_masks[test_node_indices] = 1 | ||
|
||
data.train_mask = train_masks | ||
data.val_mask = val_masks | ||
data.test_mask = test_masks | ||
|
||
if self.pre_transform is not None: | ||
data = self.pre_transform(data) | ||
|
||
data_list.append(data) | ||
|
||
data, slices = self.collate([data]) | ||
|
||
torch.save((data, slices), self.processed_paths[0]) | ||
|
||
|
||
class CSBMDataset(InMemoryDataset): | ||
def __init__(self, | ||
root, | ||
name, | ||
transform=None, | ||
pre_transform=None, | ||
pre_filter=None): | ||
self.name = name | ||
self.root = root | ||
super(CSBMDataset, self).__init__(root, transform, pre_transform, pre_filter) | ||
|
||
self.data, self.slices = torch.load(self.processed_paths[0]) | ||
|
||
@property | ||
def raw_file_names(self): | ||
return [".pkl"] | ||
|
||
@property | ||
def processed_file_names(self): | ||
return ['data.pt'] | ||
|
||
def download(self): | ||
pass | ||
|
||
def process(self): | ||
path = osp.join(self.raw_dir, '{}.pkl'.format(self.name)) | ||
data = pkl.load(open(path, 'rb')) | ||
|
||
data_list = [] | ||
|
||
random_node_indices = np.random.permutation(data.y.size(0)) | ||
training_size = int(len(random_node_indices) * 0.8) | ||
val_size = int(len(random_node_indices) * 0.1) | ||
train_node_indices = random_node_indices[:training_size] | ||
val_node_indices = random_node_indices[training_size:training_size + val_size] | ||
test_node_indices = random_node_indices[training_size + val_size:] | ||
|
||
train_masks = torch.zeros([data.y.size(0)], dtype=torch.bool) | ||
train_masks[train_node_indices] = 1 | ||
val_masks = torch.zeros([data.y.size(0)], dtype=torch.bool) | ||
val_masks[val_node_indices] = 1 | ||
test_masks = torch.zeros([data.y.size(0)], dtype=torch.bool) | ||
test_masks[test_node_indices] = 1 | ||
|
||
data.train_mask = train_masks | ||
data.val_mask = val_masks | ||
data.test_mask = test_masks | ||
|
||
if self.pre_transform is not None: | ||
data = self.pre_transform(data) | ||
|
||
data_list.append(data) | ||
|
||
data, slices = self.collate([data]) | ||
|
||
torch.save((data, slices), self.processed_paths[0]) | ||
|
||
|
||
class GraphTUDataset(InMemoryDataset): | ||
def __init__(self, | ||
root, | ||
name, | ||
transform=None, | ||
pre_transform=None, | ||
pre_filter=None): | ||
self.name = name | ||
self.root = root | ||
super(GraphTUDataset, self).__init__(root, transform, pre_transform, pre_filter) | ||
|
||
self.data, self.slices = torch.load(self.processed_paths[0]) | ||
|
||
@property | ||
def raw_file_names(self): | ||
return [".pkl"] | ||
|
||
@property | ||
def processed_file_names(self): | ||
return ['data.pt'] | ||
|
||
def download(self): | ||
pass | ||
|
||
def process(self): | ||
path = osp.join(self.raw_dir, '{}.pkl'.format(self.name)) | ||
data_list = pkl.load(open(path, 'rb')) | ||
random.shuffle(data_list) | ||
|
||
if self.pre_transform is not None: | ||
data_list = [self.pre_transform(data) for data in data_list] | ||
|
||
self.data, self.slices = self.collate(data_list) | ||
|
||
torch.save((self.data, self.slices), self.processed_paths[0]) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.