
Commit

init update
cszhangzhen committed Jan 31, 2025
1 parent b2c1678 commit ecc6bc2
Showing 12 changed files with 1,603 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.DS_Store
40 changes: 39 additions & 1 deletion README.md
@@ -1 +1,39 @@
# GraphATA
Aggregate to Adapt: Node-Centric Aggregation for Multi-Source-Free Graph Domain Adaptation (WWW-2025).

![](https://github.com/cszhangzhen/GraphATA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphATA algorithm, which addresses the multi-source domain adaptation problem without accessing the labelled source graphs. Unlike previous multi-source domain adaptation approaches that aggregate predictions at the model level, GraphATA conducts adaptation at node granularity. Specifically, we parameterize each node with its own graph convolutional matrix by automatically aggregating weight matrices from multiple source models according to the node's local context, thus realizing dynamic adaptation over graph-structured data. We also demonstrate that GraphATA generalizes to both model-centric and layer-centric methods.
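
To make the node-centric idea concrete, below is a minimal PyTorch sketch of mixing per-layer weight matrices from multiple source models with node-specific coefficients. It is illustrative only, not the exact GraphATA implementation: the function name `node_centric_weights` and the dot-product scoring of source models are assumptions for the sketch.

```
import torch
import torch.nn.functional as F

def node_centric_weights(h, source_weights):
    # h:              [num_nodes, d] node representations (local context)
    # source_weights: list of K tensors, each [d, d_out], one per source model
    W = torch.stack(source_weights)        # [K, d, d_out]
    keys = W.mean(dim=-1)                  # [K, d], crude summary of each source transform
    scores = h @ keys.t()                  # [num_nodes, K], relevance of each source to each node
    alpha = F.softmax(scores, dim=-1)      # node-specific mixing coefficients
    # Convex combination: one [d, d_out] convolution matrix per node
    return torch.einsum('nk,kio->nio', alpha, W)

# Hypothetical shapes: 5 nodes, 3 source models, 16 -> 8 transform
h = torch.randn(5, 16)
Ws = [torch.randn(16, 8) for _ in range(3)]
per_node_W = node_centric_weights(h, Ws)   # shape [5, 16, 8]
```

Each node thus receives its own convex combination of the source weight matrices, so nodes with different local contexts adapt differently.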

## Requirements
* python3.8
* pytorch==1.13.1
* torch-scatter==2.1.0
* torch-sparse==0.6.15
* torch-cluster==1.6.0
* torch-geometric==2.4.0
* numpy==1.23.4
* scipy==1.9.3

## Datasets
All datasets used in the paper are publicly available.

## Quick Start For Node Classification:
Execute the following command for source model pre-training:
```
python train_source_node.py
```
Then, execute the following command for adaptation:
```
python train_target_node.py
```

## Quick Start For Graph Classification:
Execute the following command for source model pre-training:
```
python train_source_graph.py
```
Then, execute the following command for adaptation:
```
python train_target_graph.py
```
Binary file added data.zip
Binary file not shown.
302 changes: 302 additions & 0 deletions datasets.py
@@ -0,0 +1,302 @@
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F
import random

import scipy
import pickle as pkl
from sklearn.preprocessing import label_binarize
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


class CitationDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ["docs.txt", "edgelist.txt", "labels.txt"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        content_list = []
        with open(docs_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                content_list.append(line.split(","))
        x = np.array(content_list, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        content_list = []
        with open(label_path, 'rb') as f:
            for line in f.readlines():
                line = str(line, encoding="utf-8")
                line = line.replace("\r", "").replace("\n", "")
                content_list.append(line)
        y = np.array(content_list, dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

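        # Random 80/10/10 train/val/test split over nodes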
random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])


class TwitchDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ["edges.csv, features.json, target.csv"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def load_twitch(self, lang):
assert lang in ('DE', 'EN', 'ES', 'FR', 'PTBR', 'RU'), 'Invalid dataset'
filepath = self.raw_dir
label = []
node_ids = []
src = []
targ = []
uniq_ids = set()
print(filepath)
with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
node_id = int(row[5])
# handle FR case of non-unique rows
if node_id not in uniq_ids:
uniq_ids.add(node_id)
label.append(int(row[2]=="True"))
node_ids.append(int(row[5]))

        node_ids = np.array(node_ids, dtype=np.int64)
with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
src.append(int(row[0]))
targ.append(int(row[1]))
with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
j = json.load(f)
src = np.array(src)
targ = np.array(targ)
label = np.array(label)
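        # Reindex labels by raw node id: after this block, label[i] is the label
        # of node id i (assumes ids cover 0..n-1 after deduplication).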
inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)}
reorder_node_ids = np.zeros_like(node_ids)
for i in range(label.shape[0]):
reorder_node_ids[i] = inv_node_ids[i]

n = label.shape[0]
A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n,n))
features = np.zeros((n,3170))
for node, feats in j.items():
if int(node) >= n:
continue
features[int(node), np.array(feats, dtype=int)] = 1
        # features = features[:, np.sum(features, axis=0) != 0]  # remove zero columns; not needed for the cross-graph task
new_label = label[reorder_node_ids]
label = new_label

return A, label, features

def process(self):
A, label, features = self.load_twitch(self.name)
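        # Symmetrize the adjacency matrix before extracting edge_index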
A = A.todense() + A.todense().T
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
features = np.array(features)
x = torch.from_numpy(features).to(torch.float)
y = torch.from_numpy(label).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])


class CSBMDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(CSBMDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return [".pkl"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data = pkl.load(f)

data_list = []

random_node_indices = np.random.permutation(data.y.size(0))
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([data.y.size(0)], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate([data])

torch.save((data, slices), self.processed_paths[0])


class GraphTUDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(GraphTUDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return [".pkl"]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        with open(path, 'rb') as f:
            data_list = pkl.load(f)
random.shuffle(data_list)

if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]

self.data, self.slices = self.collate(data_list)

torch.save((self.data, self.slices), self.processed_paths[0])
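

# Minimal usage sketch (not part of the original file). The root path and
# dataset name below are hypothetical; point them at wherever data.zip is
# unpacked. Each dataset exposes a single PyG Data object with train/val/test
# node masks attached during process().
if __name__ == '__main__':
    dataset = CitationDataset(root='data/acmv9', name='acmv9')
    data = dataset[0]
    print(data.num_nodes, data.num_edges, int(data.train_mask.sum()))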
Binary file added fig/model.png
