From 5b0bf2d18d2ad183dd39333a82996d27f4beffbb Mon Sep 17 00:00:00 2001 From: robby-wang Date: Wed, 17 Jan 2024 11:31:52 -0500 Subject: [PATCH 1/3] Added Ki Readout reset author too. --- mtenn/config.py | 5 +++++ mtenn/readout.py | 49 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/mtenn/config.py b/mtenn/config.py index bf92c73..3ca35c6 100644 --- a/mtenn/config.py +++ b/mtenn/config.py @@ -60,6 +60,7 @@ class ReadoutConfig(StringEnum): """ pic50 = "pic50" + ki = "ki" class CombinationConfig(StringEnum): @@ -203,6 +204,8 @@ def build(self) -> mtenn.model.Model: mtenn_pred_readout = mtenn.readout.PIC50Readout( substrate=self.pred_substrate, Km=self.pred_km ) + case ReadoutConfig.ki: + mtenn_pred_readout = mtenn.readout.KiReadout() case None: mtenn_pred_readout = None @@ -211,6 +214,8 @@ def build(self) -> mtenn.model.Model: mtenn_comb_readout = mtenn.readout.PIC50Readout( substrate=self.comb_substrate, Km=self.comb_km ) + case ReadoutConfig.ki: + mtenn_comb_readout = mtenn.readout.KiReadout() case None: mtenn_comb_readout = None diff --git a/mtenn/readout.py b/mtenn/readout.py index 41dbf2c..419bd01 100644 --- a/mtenn/readout.py +++ b/mtenn/readout.py @@ -4,7 +4,8 @@ class Readout(torch.nn.Module, abc.ABC): - pass + def __str__(self): + return repr(self) class PIC50Readout(Readout): @@ -58,9 +59,6 @@ def __init__(self, substrate: Optional[float] = None, Km: Optional[float] = None def __repr__(self): return f"PIC50Readout(substrate={self.substrate}, Km={self.Km})" - def __str__(self): - return repr(self) - def forward(self, delta_g): """ Method to convert a predicted delta G value into a pIC50 value. @@ -81,3 +79,46 @@ def forward(self, delta_g): pic50 -= torch.log10(torch.tensor(self.cp_val, dtype=delta_g.dtype)) return pic50 + + + +class KiReadout(Readout): + """ + Readout implementation to convert delta G values to Ki values. This new + implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED + PRIOR TO v0.3.0. + Assuming implicit energy units: + deltaG = ln(Ki) + Ki = exp(deltaG) + """ + + def __init__(self): + """ + Initialization. + + Parameters + ---------- + None + """ + super(KiReadout, self).__init__() #TODO Remove? + + def __repr__(self): + return f"KiReadout()" + + def forward(self, delta_g): + """ + Method to convert a predicted delta G value into a Ki value. + + Parameters + ---------- + delta_g : torch.Tensor + Input delta G value. + + Returns + ------- + float + Calculated Ki value. + """ + ki = torch.exp(delta_g).item() + + return ki From 48ffe2e0f1a53b572e3bef621a8caeb3c6cfdaf8 Mon Sep 17 00:00:00 2001 From: Robby Wang <146594460+robby-wang@users.noreply.github.com> Date: Wed, 17 Jan 2024 19:27:05 -0500 Subject: [PATCH 2/3] Applied suggestions from code review, will change the return data type for readout later Co-authored-by: kaminow <51923685+kaminow@users.noreply.github.com> --- mtenn/readout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/readout.py b/mtenn/readout.py index 419bd01..0110c48 100644 --- a/mtenn/readout.py +++ b/mtenn/readout.py @@ -100,7 +100,7 @@ def __init__(self): ---------- None """ - super(KiReadout, self).__init__() #TODO Remove? + super(KiReadout, self).__init__() def __repr__(self): return f"KiReadout()" @@ -119,6 +119,6 @@ def forward(self, delta_g): float Calculated Ki value. """ - ki = torch.exp(delta_g).item() + ki = torch.exp(delta_g) return ki From 952abbd5d979d081c5923a2cc1e75ca09e77e0ca Mon Sep 17 00:00:00 2001 From: robby-wang Date: Wed, 17 Jan 2024 19:46:31 -0500 Subject: [PATCH 3/3] Changed docstring of readout return type: float -> torch.Tensor --- mtenn/readout.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mtenn/readout.py b/mtenn/readout.py index 0110c48..cebf05c 100644 --- a/mtenn/readout.py +++ b/mtenn/readout.py @@ -70,7 +70,7 @@ def forward(self, delta_g): Returns ------- - float + torch.Tensor Calculated pIC50 value. """ pic50 = -delta_g / torch.log(torch.tensor(10, dtype=delta_g.dtype)) @@ -116,7 +116,7 @@ def forward(self, delta_g): Returns ------- - float + torch.Tensor Calculated Ki value. """ ki = torch.exp(delta_g)