Skip to content

Commit

Permalink
Merge pull request #36 from choderalab/Add_Ki_Readout
Browse files Browse the repository at this point in the history
Added Ki Readout
  • Loading branch information
robby-wang authored Jan 24, 2024
2 parents ee8a09a + 952abbd commit d517110
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
5 changes: 5 additions & 0 deletions mtenn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ReadoutConfig(StringEnum):
"""

pic50 = "pic50"
ki = "ki"


class CombinationConfig(StringEnum):
Expand Down Expand Up @@ -203,6 +204,8 @@ def build(self) -> mtenn.model.Model:
mtenn_pred_readout = mtenn.readout.PIC50Readout(
substrate=self.pred_substrate, Km=self.pred_km
)
case ReadoutConfig.ki:
mtenn_pred_readout = mtenn.readout.KiReadout()
case None:
mtenn_pred_readout = None

Expand All @@ -211,6 +214,8 @@ def build(self) -> mtenn.model.Model:
mtenn_comb_readout = mtenn.readout.PIC50Readout(
substrate=self.comb_substrate, Km=self.comb_km
)
case ReadoutConfig.ki:
mtenn_comb_readout = mtenn.readout.KiReadout()
case None:
mtenn_comb_readout = None

Expand Down
51 changes: 46 additions & 5 deletions mtenn/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


class Readout(torch.nn.Module, abc.ABC):
pass
def __str__(self):
return repr(self)


class PIC50Readout(Readout):
Expand Down Expand Up @@ -58,9 +59,6 @@ def __init__(self, substrate: Optional[float] = None, Km: Optional[float] = None
def __repr__(self):
return f"PIC50Readout(substrate={self.substrate}, Km={self.Km})"

def __str__(self):
return repr(self)

def forward(self, delta_g):
"""
Method to convert a predicted delta G value into a pIC50 value.
Expand All @@ -72,7 +70,7 @@ def forward(self, delta_g):
Returns
-------
float
torch.Tensor
Calculated pIC50 value.
"""
pic50 = -delta_g / torch.log(torch.tensor(10, dtype=delta_g.dtype))
Expand All @@ -81,3 +79,46 @@ def forward(self, delta_g):
pic50 -= torch.log10(torch.tensor(self.cp_val, dtype=delta_g.dtype))

return pic50



class KiReadout(Readout):
"""
Readout implementation to convert delta G values to Ki values. This new
implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED
PRIOR TO v0.3.0.
Assuming implicit energy units:
deltaG = ln(Ki)
Ki = exp(deltaG)
"""

def __init__(self):
"""
Initialization.
Parameters
----------
None
"""
super(KiReadout, self).__init__()

def __repr__(self):
return f"KiReadout()"

def forward(self, delta_g):
"""
Method to convert a predicted delta G value into a Ki value.
Parameters
----------
delta_g : torch.Tensor
Input delta G value.
Returns
-------
torch.Tensor
Calculated Ki value.
"""
ki = torch.exp(delta_g)

return ki

0 comments on commit d517110

Please sign in to comment.