diff --git a/tps-core/src/main/scala/tps/GraphParsing.scala b/tps-core/src/main/scala/tps/GraphParsing.scala index 98e89e8a..a468d2f6 100644 --- a/tps-core/src/main/scala/tps/GraphParsing.scala +++ b/tps-core/src/main/scala/tps/GraphParsing.scala @@ -43,4 +43,17 @@ object GraphParsing { } sol } + + def aggregateBySmallestLabelSets[T]( + allLabels: Seq[(Edge, Set[T])] + ): Map[Edge, Set[T]] = { + val edgeToLabelSets = allLabels.groupBy(_._1).map { + case (edge, pairs) => edge -> pairs.map(_._2) + } + for ((edge, labelSets) <- edgeToLabelSets) yield { + val minLabelSetSize = labelSets.map(_.size).min + val labelSetsWithMinSize = labelSets.filter(_.size == minLabelSetSize) + edge -> labelSetsWithMinSize.flatten.toSet + } + } } diff --git a/tps-core/src/main/scala/tps/evaluation/NetworkReferenceComparison.scala b/tps-core/src/main/scala/tps/evaluation/NetworkReferenceComparison.scala index 71905db9..c40b742b 100644 --- a/tps-core/src/main/scala/tps/evaluation/NetworkReferenceComparison.scala +++ b/tps-core/src/main/scala/tps/evaluation/NetworkReferenceComparison.scala @@ -3,7 +3,6 @@ package tps.evaluation import tps._ import tps.Graphs._ import tps.SignedDirectedGraphOps._ -import parsing.DBNNetworkParser import java.io.File import tps.util.MathUtils @@ -33,10 +32,6 @@ object NetworkReferenceComparison { println(aggregateResult) } - def resultString(res: NetworkReferenceComparisonResult): String = { - ??? - } - def aggregateResults( results: Iterable[NetworkReferenceComparisonResult] ): NetworkReferenceComparisonResult = { @@ -48,58 +43,14 @@ object NetworkReferenceComparison { MathUtils.median(results.map(_.nbCandidateEdges)), results.head.nbReferenceEdges, MathUtils.median(results.map(_.nbCommonEdges)), - MathUtils.average(results.map(_.undirectedPrecision)), - MathUtils.average(results.map(_.undirectedRecall)), MathUtils.median(results.map(_.nbDirectedEdges)), - MathUtils.median(results.map(_.directedEdgeRatio)), MathUtils.median(results.map(_.nbCommonDirectedEdges)), - MathUtils.median(results.map(_.commonDirectedEdgeRatio)), MathUtils.median(results.map(_.nbMatchingDirectionEdges)), - MathUtils.median(results.map(_.matchingDirectionEdgeRatio)), MathUtils.median(results.map(_.nbConflictingDirectionEdges)), - MathUtils.median(results.map(_.conflictingDirectionEdgeRatio)), - MathUtils.median(results.map(_.nbUnconfirmedDirectionEdges)), - MathUtils.median(results.map(_.unconfirmedDirectionEdgeRatio)) + MathUtils.median(results.map(_.nbUnconfirmedDirectionEdges)) ) } - // legacy analysis - def runComparativeAnalysis(refFn: String): Unit = { - val refFile = new File(refFn) - val refName = refFile.getName() - val (refNetwork, refEvidence) = ReferenceParser.run(refFile) - - val networkFolder = new File("data/networks/evaluation") - - val dbnFn = s"$networkFolder/DBN.tsv" - val dbnFile = new File(dbnFn) - val dbnMinProb = 0.025 - val dbnNetwork = DBNNetworkParser.run(dbnFile, dbnMinProb) - - // PIN + kin-sub edges - val pinFn = "data/networks/directed-pin-with-resource-edges.tsv" - val pinFile = new File(pinFn) - val pin = toSignedDirectedGraph(PINParser.run(pinFile)) - - // SIF networks - // undirected TXN + TPS (time series + kinase-substrate) - val sifFiles = networkFolder.listFiles() filter { f => - f.getName().endsWith(".sif") - } - val sifNetworks = sifFiles map { f => - f.getName() -> SignedDirectedGraphParser.run(f) - } - - val candidates = Map( - "PIN" -> pin, - "DBN" -> dbnNetwork - ) ++ sifNetworks - - for ((id, network) <- candidates) { - println(resultString(compare(network, refNetwork, id, refName))) - } - } - // Use doubles instead of integers to compute aggregate statistics case class NetworkReferenceComparisonResult( candidateName: String, @@ -107,36 +58,75 @@ object NetworkReferenceComparison { nbCandidateEdges: Double, nbReferenceEdges: Double, nbCommonEdges: Double, - undirectedPrecision: Double, - undirectedRecall: Double, nbDirectedEdges: Double, - directedEdgeRatio: Double, nbCommonDirectedEdges: Double, - commonDirectedEdgeRatio: Double, nbMatchingDirectionEdges: Double, - matchingDirectionEdgeRatio: Double, nbConflictingDirectionEdges: Double, - conflictingDirectionEdgeRatio: Double, - nbUnconfirmedDirectionEdges: Double, - unconfirmedDirectionEdgeRatio: Double - ) + nbUnconfirmedDirectionEdges: Double + ) { + override def toString: String = { + val values = List( + referenceName, + candidateName, + nbReferenceEdges, + nbCandidateEdges, + nbCommonEdges, + metricString(precisionLowerBound(nbCommonEdges, nbCandidateEdges)), + metricString(recall(nbCommonEdges, nbReferenceEdges)), + nbDirectedEdges, + nbCommonDirectedEdges, + nbMatchingDirectionEdges, + nbConflictingDirectionEdges, + nbUnconfirmedDirectionEdges, + metricString( + precisionLowerBound(nbMatchingDirectionEdges, nbCommonDirectedEdges)), + metricString( + precisionUpperBound(nbConflictingDirectionEdges, nbCommonDirectedEdges)), + metricString( + nbMatchingDirectionEdges / nbDirectedEdges * 1000.0), + metricString( + nbConflictingDirectionEdges / nbDirectedEdges * 1000.0) + ) + + values.mkString(",") + } + } + + def metricString(precision: Double): String = { + if (precision.isNaN) "N/A" else precision.toString + } + + def precisionLowerBound( + nbPos: Double, + nbPredictions: Double + ): Double = { + nbPos / nbPredictions + } + + def precisionUpperBound( + nbNeg: Double, + nbPredictions: Double + ): Double = { + (nbPredictions - nbNeg) / nbPredictions + } + + def recall( + nbPos: Double, + nbTotal: Double + ): Double = { + nbPos / nbTotal + } def compare( candidate: SignedDirectedGraph, reference: SignedDirectedGraph, candidateName: String, referenceName: String - ) = { - // compute precision and recall where relevance is whether a selected edge - // is in the reference - + ): NetworkReferenceComparisonResult = { val candidateE = candidate.keySet val referenceE = reference.keySet val commonE = candidateE intersect referenceE - var undirPrecision = commonE.size.toDouble / candidateE.size.toDouble - if (candidateE.isEmpty) undirPrecision = 0.0 - val undirRecall = commonE.size.toDouble / referenceE.size.toDouble val directedE = candidateE filter { e => oneActiveDirection(candidate(e)) @@ -172,18 +162,11 @@ object NetworkReferenceComparison { candidateE.size, referenceE.size, commonE.size, - undirPrecision, - undirRecall, directedE.size, - directedE.size.toDouble / candidateE.size, commonDirectedE.size, - commonDirectedE.size.toDouble / directedE.size.toDouble, matchingDirectionE.size, - matchingDirectionE.size.toDouble / directedE.size.toDouble, conflictingE.size, - conflictingE.size.toDouble / directedE.size.toDouble, - unconfirmedDirectionE.size, - unconfirmedDirectionE.size.toDouble / directedE.size.toDouble + unconfirmedDirectionE.size ) } diff --git a/tps-core/src/main/scala/tps/evaluation/kegg/KEGG2ReferenceConversion.scala b/tps-core/src/main/scala/tps/evaluation/kegg/KEGG2ReferenceConversion.scala new file mode 100644 index 00000000..fbefbe49 --- /dev/null +++ b/tps-core/src/main/scala/tps/evaluation/kegg/KEGG2ReferenceConversion.scala @@ -0,0 +1,24 @@ +package tps.evaluation.kegg + +import java.io.File + +import tps.SignedDirectedNetworkPrinter +import tps.util.FileUtils + +object KEGG2ReferenceConversion { + + def main(args: Array[String]): Unit = { + assert(args.size == 2) + + val inputFile = new File(args(0)) + val outputFile = new File(args(1)) + + val sifNetwork = KEGGParser.parseMostSpecificLabels(inputFile) + + FileUtils.writeToFile( + outputFile, + SignedDirectedNetworkPrinter.print(sifNetwork) + ) + } + +} diff --git a/tps-core/src/main/scala/tps/evaluation/kegg/KEGGParser.scala b/tps-core/src/main/scala/tps/evaluation/kegg/KEGGParser.scala new file mode 100644 index 00000000..1d681328 --- /dev/null +++ b/tps-core/src/main/scala/tps/evaluation/kegg/KEGGParser.scala @@ -0,0 +1,37 @@ +package tps.evaluation.kegg + +import java.io.File + +import tps.GraphParsing +import tps.Graphs.SignedDirectedGraph +import tps.TSVSource + +object KEGGParser { + + def parseMostSpecificLabels(f: File): SignedDirectedGraph = { + val data = new TSVSource(f, noHeaders = false).data + + val pairs = data.tuples map { tuple => + val Seq(id1, tpe, id2) = tuple + val edge = GraphParsing.lexicographicEdge(id1, id2) + + tpe match { + case "activation" => edge -> Set( + GraphParsing.lexicographicActivation(id1, id2) + ) + case "inhibition" => edge -> Set( + GraphParsing.lexicographicInhibition(id1, id2) + ) + case "binding/association" | "expression" => edge -> Set( + GraphParsing.lexicographicActivation(id1, id2), + GraphParsing.lexicographicInhibition(id1, id2), + GraphParsing.lexicographicActivation(id2, id1), + GraphParsing.lexicographicInhibition(id2, id1) + ) + } + } + + GraphParsing.aggregateBySmallestLabelSets(pairs) + } + +}