Commit 2becfbc: removed unused codes

kotori-y committed Jun 5, 2024 (1 parent: 023abbf)

Showing 5 changed files with 20 additions and 72 deletions.
74 changes: 14 additions & 60 deletions MTLKcatKM/encoder/__init__.py
@@ -1,21 +1,16 @@
 import torch
 from torch import nn
-from torch.utils.data import RandomSampler
-from torch_geometric.loader import DataLoader
 
 from MTLKcatKM.encoder.ligand_encoder import LigandEncoderMoleBert
 from MTLKcatKM.encoder.protein_encoder import ProteinEncoder
 from MTLKcatKM.encoder.auxiliary_encoder import AuxiliaryEncoder
-from MTLKcatKM.layers.attention import MultiHeadAttentionLayer
-from MTLKcatKM.layers.mlp import MLP
 
 
 class ComplexEncoder(nn.Module):
     def __init__(
             self, ligand_enc: LigandEncoderMoleBert, protein_enc: ProteinEncoder,
             auxiliary_enc: AuxiliaryEncoder, use_esm2=False,
-            dropout=0.2, use_attention=False, use_ph=True, use_temperature=True, use_organism=True,
-            atten_heads=16, device=None
+            use_ph=True, use_temperature=True, use_organism=True, device=None
     ):
         super().__init__()
         self.device = device
@@ -39,19 +34,19 @@ def __init__(
         self.ligand_hidden = self.ligand_enc.embed_dim
         self.protein_hidden = self.protein_enc.embed_dim
 
-        self.use_attention = use_attention
-        if self.use_attention:
-            self.attn_layer = MultiHeadAttentionLayer(
-                hid_dim=self.protein_hidden,
-                n_heads=atten_heads,
-                dropout=dropout,
-                device=self.device
-            )
-            self.attn_norm_layer = nn.LayerNorm(self.protein_hidden)
-            self.ff_norm_layer = nn.LayerNorm(self.protein_hidden)
-
-        self.ff_layer = MLP(self.protein_hidden, [self.protein_hidden * 2, self.protein_hidden], dropout=dropout)
-        self.dropout_layer = nn.Dropout(dropout)
+        # self.use_attention = use_attention
+        # if self.use_attention:
+        #     self.attn_layer = MultiHeadAttentionLayer(
+        #         hid_dim=self.protein_hidden,
+        #         n_heads=atten_heads,
+        #         dropout=dropout,
+        #         device=self.device
+        #     )
+        #     self.attn_norm_layer = nn.LayerNorm(self.protein_hidden)
+        #     self.ff_norm_layer = nn.LayerNorm(self.protein_hidden)
+        #
+        # self.ff_layer = MLP(self.protein_hidden, [self.protein_hidden * 2, self.protein_hidden], dropout=dropout)
+        # self.dropout_layer = nn.Dropout(dropout)
 
     def make_src_mask(self, src):
         # src = [batch size, src len]
@@ -77,44 +72,3 @@ def forward(self, input_ids, attention_mask, mol_graph, organ_ids, condition):
             h_aux = None
 
         return h_lig, h_prot, h_aux
-
-
-if __name__ == "__main__":
-    from MTLKcatKM.datasets import MTLKDataset
-
-    ligand_model_path = "../pretrained/checkpoints/model.pth"
-    protein_model_path = "../../ProtTrans/models/prot_t5_xl_uniref50"
-
-    dataset_params = {
-        "sequence_idx": 0,
-        "smiles_idx": 1,
-        "label_idx": [7, 8],
-        "max_length": 512
-    }
-
-    train_dataset = MTLKDataset(data_path='../data/modeling/train_dataset.csv', model_name=protein_model_path, **dataset_params)
-    sampler_train = RandomSampler(train_dataset)
-    batch_sampler_train = torch.utils.data.BatchSampler(
-        sampler_train, 4, drop_last=True
-    )
-
-    train_loader = DataLoader(
-        train_dataset,
-        batch_sampler=batch_sampler_train,
-        num_workers=10,
-    )
-
-    enc1 = LigandEncoder(ligand_model_path)
-    enc2 = ProteinEncoder(protein_model_path)
-    encoder = ComplexEncoder(ligand_enc=enc1, protein_enc=enc2)
-
-    for step, batch in enumerate(train_loader):
-        net_input = {
-            "mol_graph": batch["mol_graph"],
-            "input_ids": batch["input_ids"],
-            "attention_mask": batch["attention_mask"]
-        }
-
-        net_outputs = encoder(**net_input)
-
-        break
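
Note on the block commented out above: those layers describe a transformer-style residual unit over the protein representation (attention, dropout, residual add, LayerNorm, then a feed-forward MLP with its own residual and LayerNorm). A minimal sketch of what such a unit computes follows. It is an illustration only: torch.nn.MultiheadAttention stands in for the repository's MultiHeadAttentionLayer (whose forward signature is not visible in this diff), and the post-norm ordering is an assumption, not the project's confirmed implementation.

```python
import torch
from torch import nn


class AttentionFusionBlock(nn.Module):
    """Sketch of the disabled attention path; hyperparameters mirror the
    commented constructor calls (hid_dim = protein_hidden, 16 heads, p = 0.2)."""

    def __init__(self, hid_dim: int = 1024, n_heads: int = 16, dropout: float = 0.2):
        super().__init__()
        # stand-in for MTLKcatKM.layers.attention.MultiHeadAttentionLayer
        self.attn = nn.MultiheadAttention(hid_dim, n_heads, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(hid_dim)
        self.ff_norm = nn.LayerNorm(hid_dim)
        # analogue of MLP(hid_dim, [hid_dim * 2, hid_dim], dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim * 2, hid_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, h_prot: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # self-attention with residual connection and LayerNorm
        attn_out, _ = self.attn(h_prot, h_prot, h_prot, key_padding_mask=key_padding_mask)
        h = self.attn_norm(h_prot + self.dropout(attn_out))
        # position-wise feed-forward with its own residual and LayerNorm
        return self.ff_norm(h + self.dropout(self.ff(h)))


block = AttentionFusionBlock(hid_dim=64, n_heads=4)
print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```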
10 changes: 3 additions & 7 deletions MTLKcatKM/encoder/ligand_encoder.py
@@ -1,14 +1,10 @@
-import os
-
 import torch
 from torch import nn
 
 import sys
 sys.path.append("../..")
 from Mole_BERT.model import GNN_graphpred
 
-from MTLKcatKM.encoder.gin import GINet
-
 MODEL_CONFIG = {
     "num_layer": 5,  # number of graph conv layers
     "emb_dim": 300,  # embedding dimension in graph conv layers
@@ -31,17 +27,17 @@ def __init__(self, init_model=None, device=None, frozen_params=False):
         self.encoder = GNN_graphpred(5, 300, drop_ratio=0.5)
         if init_model is not None and init_model != "":
             self.encoder.from_pretrained(init_model, device=self.device)
-            print("Loaded pre-trained model with success.")
+            # print("Loaded pre-trained model with success.")
 
         if frozen_params:
-            print(f"frozen {self}")
+            # print(f"frozen {self}")
             for p in self.encoder.parameters():
                 p.requires_grad = False
         else:
             for name, p in self.encoder.named_parameters():
                 if "encoder.gnn.batch_norms.4" not in name:
                     p.requires_grad = False
-            print(f"finetune {self}")
+            # print(f"finetune {self}")
 
     def forward(self, mol_graph):
         if self.frozen_params:
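
The else branch above freezes every encoder parameter except those under encoder.gnn.batch_norms.4, so only the last batch-norm layer is fine-tuned. A self-contained sketch of that selective-freezing pattern on a toy model (the model and the matching substring are placeholders, not the repository's):

```python
from torch import nn


def set_trainable(model: nn.Module, frozen: bool, keep: str = "") -> None:
    """Freeze all parameters; if not fully frozen, leave only parameters
    whose name contains `keep` trainable (mirrors the diff's else branch)."""
    for name, p in model.named_parameters():
        p.requires_grad = (not frozen) and (keep in name)


# toy usage: fine-tune only the BatchNorm layer ("1" in nn.Sequential's naming)
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
set_trainable(model, frozen=False, keep="1")
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['1.weight', '1.bias']
```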
3 changes: 2 additions & 1 deletion MTLKcatKM/evaluate.py
@@ -40,7 +40,7 @@ def evaluate(model: MTLModel, loader, device, normalizer: TrainNormalizer, args)
 
     pred_result = {}
 
-    pbar = tqdm(loader, desc="Iteration", disable=False)
+    pbar = tqdm(loader, desc="Iteration", disable=True)
 
     for step, batch in enumerate(pbar):
         net_input = {
@@ -149,6 +149,7 @@ def main(args):
 
     print("Evaluating...")
     test_pred = evaluate(model, test_loader, device=device, normalizer=train_normalizer, args=args)
+    # print("\n") # avoid bug
     print(test_pred)
     # pred_df = pd.DataFrame(test_pred)
     # pred_df.to_csv(f'{args.result_file}', index=False)
3 changes: 1 addition & 2 deletions MTLKcatKM/model.py
@@ -37,8 +37,7 @@ def __init__(
 
         self.encoder = ComplexEncoder(
             ligand_enc=ligand_enc, protein_enc=protein_enc, auxiliary_enc=auxiliary_enc,
-            use_ph=use_ph, use_temperature=use_temperature, use_organism=use_organism,
-            dropout=dropout, device=self.device,
+            use_ph=use_ph, use_temperature=use_temperature, use_organism=use_organism, device=self.device,
         )
         self.pro2lig = nn.Linear(protein_enc.embed_dim, ligand_enc.embed_dim)
 
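For context, pro2lig is a linear projection from the protein embedding space into the ligand embedding space, which lets the two representations be combined elementwise. A tiny sketch of the idea; the dimensions (1024 for ProtT5, 300 for the Mole-BERT GNN) and the additive fusion are illustrative assumptions, since the rest of forward is not shown in this diff:

```python
import torch
from torch import nn

prot_dim, lig_dim = 1024, 300           # assumed embedding sizes
pro2lig = nn.Linear(prot_dim, lig_dim)  # maps protein space -> ligand space

h_prot = torch.randn(4, prot_dim)  # batch of pooled protein embeddings
h_lig = torch.randn(4, lig_dim)    # batch of ligand (molecule) embeddings
h_fused = h_lig + pro2lig(h_prot)  # one plausible fusion once shapes match
print(h_fused.shape)               # torch.Size([4, 300])
```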
2 changes: 0 additions & 2 deletions MTLKcatKM/train.py
@@ -426,9 +426,7 @@ def main_cli():
 
     parser.add_argument("--max_length", type=int, default=4)
 
-    # parser.add_argument("--molclr_path", type=str)
     parser.add_argument("--prottrans_path", type=str)
-    # parser.add_argument("--esm_dir", type=str, default=None)
     parser.add_argument("--molebert_dir", type=str, default=None)
 
     parser.add_argument("--tower_hid_layer", type=int, default=1)