From 290f719f799ad2739557b78f92b560a3e8904d52 Mon Sep 17 00:00:00 2001 From: Anton Lee Date: Thu, 21 Sep 2023 13:38:27 +1200 Subject: [PATCH] Add PredictionLoggerEvaluator --- .../ClassificationPerformanceEvaluator.java | 3 - .../LearningPerformanceEvaluator.java | 6 +- .../evaluation/PredictionLoggerEvaluator.java | 162 ++++++++++++++++++ .../moa/tasks/EvaluateInterleavedChunks.java | 5 + .../EvaluateInterleavedTestThenTrain.java | 7 +- .../main/java/moa/tasks/EvaluateModel.java | 36 +--- .../tasks/EvaluatePeriodicHeldOutTest.java | 53 +----- .../java/moa/tasks/EvaluatePrequential.java | 45 +---- .../java/moa/tasks/EvaluatePrequentialCV.java | 6 + .../moa/tasks/EvaluatePrequentialDelayed.java | 35 +--- .../tasks/EvaluatePrequentialDelayedCV.java | 6 + .../tasks/EvaluatePrequentialMultiLabel.java | 40 +---- ...aluatePrequentialMultiTargetSemiSuper.java | 24 +-- .../tasks/EvaluatePrequentialRegression.java | 37 +--- .../WriteConfigurationToJupyterNotebook.java | 9 - .../meta/ALPrequentialEvaluationTask.java | 6 +- 16 files changed, 231 insertions(+), 249 deletions(-) create mode 100644 moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java diff --git a/moa/src/main/java/moa/evaluation/ClassificationPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/ClassificationPerformanceEvaluator.java index 0eb9176c9..cc6acedcf 100644 --- a/moa/src/main/java/moa/evaluation/ClassificationPerformanceEvaluator.java +++ b/moa/src/main/java/moa/evaluation/ClassificationPerformanceEvaluator.java @@ -20,10 +20,7 @@ package moa.evaluation; import com.yahoo.labs.samoa.instances.Instance; -import moa.MOAObject; import moa.core.Example; -import moa.core.Measurement; public interface ClassificationPerformanceEvaluator extends LearningPerformanceEvaluator> { - } diff --git a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java index a7c655be8..3337b9cf7 100644 --- a/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java +++ b/moa/src/main/java/moa/evaluation/LearningPerformanceEvaluator.java @@ -35,7 +35,7 @@ * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ -public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler { +public interface LearningPerformanceEvaluator extends MOAObject, CapabilitiesHandler, AutoCloseable { /** * Resets this evaluator. It must be similar to @@ -66,4 +66,8 @@ default ImmutableCapabilities defineImmutableCapabilities() { return new ImmutableCapabilities(Capability.VIEW_STANDARD); } + @Override + default void close() throws Exception { + // By default an evaluator does nothing when closed. + } } diff --git a/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java b/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java new file mode 100644 index 000000000..b597f899e --- /dev/null +++ b/moa/src/main/java/moa/evaluation/PredictionLoggerEvaluator.java @@ -0,0 +1,162 @@ +package moa.evaluation; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.Arrays; +import java.util.zip.GZIPOutputStream; + +import com.github.javacliparser.FileOption; +import com.github.javacliparser.FlagOption; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Prediction; + +import moa.capabilities.Capability; +import moa.capabilities.ImmutableCapabilities; +import moa.core.Example; +import moa.core.Measurement; +import moa.core.ObjectRepository; +import moa.core.Utils; +import moa.options.AbstractOptionHandler; +import moa.options.ClassOption; +import moa.tasks.TaskMonitor; + +public class PredictionLoggerEvaluator extends AbstractOptionHandler + implements ClassificationPerformanceEvaluator { + + private static final long serialVersionUID = 1L; + + private OutputStreamWriter writer; + private int index = 0; + + public FileOption csvFileOption = new FileOption("predictionLog", 'o', + "A file to write comma separated values to.", null, "csv.gzip", true); + + public FlagOption overwrite = new FlagOption("overwrite", 'f', "Overwrite existing file."); + + public ClassOption wrappedEvaluatorOption = new ClassOption("evaluator", 'e', + "Classification performance evaluation method.", ClassificationPerformanceEvaluator.class, + "BasicClassificationPerformanceEvaluator"); + + public FlagOption probabilities = new FlagOption("probabilities", 'p', + "Log probabilities instead of raw predictions."); + + public FlagOption uncompressed = new FlagOption("uncompressed", 'u', + "The output file should be saved uncompressed."); + + private ClassificationPerformanceEvaluator wrappedEvaluator; + + @Override + public String getPurposeString() { + return "Log raw predictions and probabilities to a CSV file, and evaluate using a wrapped evaluator."; + } + + @Override + public void addResult(Example example, double[] classVotes) { + Instance instance = example.getData(); + int predictedClass = Utils.maxIndex(classVotes); + double normalizingFactor = Arrays.stream(classVotes).sum(); + int numClasses = instance.numClasses(); + + if (normalizingFactor == 0) { + normalizingFactor = 1; + } + try { + // If this is the first result, write the header to the top of the file + if (index == 0) + writeHeader(numClasses); + + + // Add row to CSV file + if (instance.classIsMissing() == true) + { + writer.write(String.format("?,%d,", predictedClass)); + } + else + { + int trueClass = (int) instance.classValue(); + writer.write(String.format("%d,%d,", trueClass, predictedClass)); + } + + if (probabilities.isSet()) { + for (int i = 0; i < numClasses; i++) { + double probability = 0.0; + if (i < classVotes.length){ + probability = classVotes[i] / normalizingFactor; + } + writer.write(String.format("%.2f,", probability)); + } + } + + writer.write("\n"); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // Pass result to wrapped evaluator + wrappedEvaluator.addResult(example, classVotes); + index ++; + } + + @Override + public void addResult(Example testInst, Prediction prediction) { + // addResult(testInst, prediction.getVotes()); + throw new RuntimeException("Not implemented"); + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + wrappedEvaluator = (ClassificationPerformanceEvaluator) getPreparedClassOption(wrappedEvaluatorOption); + try { + File file = csvFileOption.getFile(); + if (file.exists() && !overwrite.isSet()) { + throw new RuntimeException( + "File already exists: " + file.getAbsolutePath() + + ". MOA doesn't want to overwrite it."); + } + if (uncompressed.isSet()) + writer = new OutputStreamWriter(new FileOutputStream(file)); + else + writer = new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private void writeHeader(int numClasses) throws IOException { + writer.write("true_class,class_prediction,"); + if (probabilities.isSet()) { + for (int i = 0; i < numClasses; i++) { + writer.write(String.format("class_probability_%d,", i)); + } + } + writer.write("\n"); + } + + @Override + public void close() throws Exception { + writer.close(); + } + + @Override + public void reset() { + wrappedEvaluator.reset(); + } + + @Override + public Measurement[] getPerformanceMeasurements() { + return wrappedEvaluator.getPerformanceMeasurements(); + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + sb.append(getPurposeString()); + } + + @Override + public ImmutableCapabilities defineImmutableCapabilities() { + return new ImmutableCapabilities(Capability.VIEW_STANDARD); + } +} diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java index 0b01516d8..3bbd41d57 100644 --- a/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java +++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedChunks.java @@ -287,6 +287,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java index e0c69a5bd..a1238a254 100644 --- a/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java +++ b/moa/src/main/java/moa/tasks/EvaluateInterleavedTestThenTrain.java @@ -25,7 +25,6 @@ import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.core.Example; import moa.core.Measurement; @@ -40,7 +39,6 @@ import com.github.javacliparser.IntOption; import moa.streams.ExampleStream; import moa.streams.InstanceStream; -import com.yahoo.labs.samoa.instances.Instance; /** * Task for evaluating a classifier on a stream by testing then training with each example in sequence. @@ -217,6 +215,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluateModel.java b/moa/src/main/java/moa/tasks/EvaluateModel.java index 334e59e78..b882e4308 100644 --- a/moa/src/main/java/moa/tasks/EvaluateModel.java +++ b/moa/src/main/java/moa/tasks/EvaluateModel.java @@ -19,9 +19,6 @@ */ package moa.tasks; -import java.io.File; -import java.io.FileOutputStream; -import java.io.PrintStream; import com.github.javacliparser.FileOption; import com.github.javacliparser.IntOption; import moa.capabilities.CapabilitiesHandler; @@ -32,7 +29,6 @@ import moa.core.Example; import moa.core.Measurement; import moa.core.ObjectRepository; -import moa.core.Utils; import moa.evaluation.LearningEvaluation; import moa.evaluation.LearningPerformanceEvaluator; import moa.evaluation.preview.LearningCurve; @@ -40,7 +36,6 @@ import moa.options.ClassOption; import moa.streams.ExampleStream; import moa.streams.InstanceStream; -import com.yahoo.labs.samoa.instances.Instance; /** * Task for evaluating a static model on a stream. @@ -107,35 +102,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { long instancesProcessed = 0; monitor.setCurrentActivity("Evaluating model...", -1.0); - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } while (stream.hasMoreInstances() && ((maxInstances < 0) || (instancesProcessed < maxInstances))) { Example testInst = (Example) stream.nextInstance();//.copy(); - int trueClass = (int) ((Instance) testInst.getData()).classValue(); - //testInst.setClassMissing(); double[] prediction = model.getVotesForInstance(testInst); - //evaluator.addClassificationAttempt(trueClass, prediction, testInst - // .weight()); - if (outputPredictionFile != null) { - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," +( - ((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass)); - } evaluator.addResult(testInst, prediction); instancesProcessed++; @@ -169,8 +139,10 @@ public Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { } } } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java b/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java index 3e8a511ae..535c33744 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java +++ b/moa/src/main/java/moa/tasks/EvaluatePeriodicHeldOutTest.java @@ -30,7 +30,6 @@ import com.github.javacliparser.IntOption; import moa.capabilities.Capability; import moa.capabilities.ImmutableCapabilities; -import moa.classifiers.Classifier; import moa.classifiers.MultiClassClassifier; import moa.core.Example; import moa.core.Measurement; @@ -140,12 +139,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { } testStream = new CachedInstancesStream(testInstances); } else { - //testStream = (InstanceStream) stream.copy(); testStream = stream; - /*monitor.setCurrentActivity("Skipping test examples...", -1.0); - for (int i = 0; i < testSize; i++) { - stream.nextInstance(); - }*/ } instancesProcessed = 0; TimingUtils.enablePreciseTiming(); @@ -191,10 +185,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { break; } Example testInst = (Example) testStream.nextInstance(); //.copy(); - double trueClass = ((Instance) testInst.getData()).classValue(); - //testInst.setClassMissing(); double[] prediction = learner.getVotesForInstance(testInst); - //testInst.setClassValue(trueClass); evaluator.addResult(testInst, prediction); testInstancesProcessed++; if (testInstancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) { @@ -242,49 +233,15 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (monitor.resultPreviewRequested()) { monitor.setLatestResultPreview(learningCurve.copy()); } - // if (learner instanceof HoeffdingTree - // || learner instanceof HoeffdingOptionTree) { - // int numActiveNodes = (int) Measurement.getMeasurementNamed( - // "active learning leaves", - // modelMeasurements).getValue(); - // // exit if tree frozen - // if (numActiveNodes < 1) { - // break; - // } - // int numNodes = (int) Measurement.getMeasurementNamed( - // "tree size (nodes)", modelMeasurements) - // .getValue(); - // if (numNodes == lastNumNodes) { - // noGrowthCount++; - // } else { - // noGrowthCount = 0; - // } - // lastNumNodes = numNodes; - // } else if (learner instanceof OzaBoost || learner instanceof - // OzaBag) { - // double numActiveNodes = Measurement.getMeasurementNamed( - // "[avg] active learning leaves", - // modelMeasurements).getValue(); - // // exit if all trees frozen - // if (numActiveNodes == 0.0) { - // break; - // } - // int numNodes = (int) (Measurement.getMeasurementNamed( - // "[avg] tree size (nodes)", - // learner.getModelMeasurements()).getValue() * Measurement - // .getMeasurementNamed("ensemble size", - // modelMeasurements).getValue()); - // if (numNodes == lastNumNodes) { - // noGrowthCount++; - // } else { - // noGrowthCount = 0; - // } - // lastNumNodes = numNodes; - // } } if (immediateResultStream != null) { immediateResultStream.close(); } + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequential.java b/moa/src/main/java/moa/tasks/EvaluatePrequential.java index 80034890e..72ee47b67 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequential.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequential.java @@ -45,8 +45,6 @@ import com.github.javacliparser.FloatOption; import com.github.javacliparser.IntOption; import moa.streams.ExampleStream; -import com.yahoo.labs.samoa.instances.Instance; -import moa.core.Utils; /** * Task for evaluating a classifier on a stream by testing then training with each example in sequence. @@ -97,9 +95,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -123,21 +118,18 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { //New for prequential methods if (evaluator instanceof WindowClassificationPerformanceEvaluator) { - //((WindowClassificationPerformanceEvaluator) evaluator).setWindowWidth(widthOption.getValue()); if (widthOption.getValue() != 1000) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (WindowClassificationPerformanceEvaluator -w " + widthOption.getValue() + ")"); return learningCurve; } } if (evaluator instanceof EWMAClassificationPerformanceEvaluator) { - //((EWMAClassificationPerformanceEvaluator) evaluator).setalpha(alphaOption.getValue()); if (alphaOption.getValue() != .01) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (EWMAClassificationPerformanceEvaluator -a " + alphaOption.getValue() + ")"); return learningCurve; } } if (evaluator instanceof FadingFactorClassificationPerformanceEvaluator) { - //((FadingFactorClassificationPerformanceEvaluator) evaluator).setalpha(alphaOption.getValue()); if (alphaOption.getValue() != .01) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (FadingFactorClassificationPerformanceEvaluator -a " + alphaOption.getValue() + ")"); return learningCurve; @@ -168,23 +160,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -194,20 +169,14 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { && ((maxInstances < 0) || (instancesProcessed < maxInstances)) && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) { Example trainInst = stream.nextInstance(); - Example testInst = (Example) trainInst; //.copy(); - //testInst.setClassMissing(); - double[] prediction = learner.getVotesForInstance(testInst); - // Output prediction - if (outputPredictionFile != null) { - int trueClass = (int) ((Instance) trainInst.getData()).classValue(); - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," + ( - ((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass)); - } + Example testInst = (Example) trainInst; - //evaluator.addClassificationAttempt(trueClass, prediction, testInst.weight()); + double[] prediction = learner.getVotesForInstance(testInst); evaluator.addResult(testInst, prediction); + learner.trainOnInstance(trainInst); instancesProcessed++; + if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 || stream.hasMoreInstances() == false) { long evaluateTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -267,8 +236,10 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java index 75ed8c46b..567fbc5ff 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialCV.java @@ -257,6 +257,12 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + for (LearningPerformanceEvaluator evaluator : evaluators) + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java index 40768892f..4baa05e25 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayed.java @@ -120,9 +120,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -150,21 +147,18 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { //New for prequential methods if (evaluator instanceof WindowClassificationPerformanceEvaluator) { - //((WindowClassificationPerformanceEvaluator) evaluator).setWindowWidth(widthOption.getValue()); if (widthOption.getValue() != 1000) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (WindowClassificationPerformanceEvaluator -w " + widthOption.getValue() + ")"); return learningCurve; } } if (evaluator instanceof EWMAClassificationPerformanceEvaluator) { - //((EWMAClassificationPerformanceEvaluator) evaluator).setalpha(alphaOption.getValue()); if (alphaOption.getValue() != .01) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (EWMAClassificationPerformanceEvaluator -a " + alphaOption.getValue() + ")"); return learningCurve; } } if (evaluator instanceof FadingFactorClassificationPerformanceEvaluator) { - //((FadingFactorClassificationPerformanceEvaluator) evaluator).setalpha(alphaOption.getValue()); if (alphaOption.getValue() != .01) { System.out.println("DEPRECATED! Use EvaluatePrequential -e (FadingFactorClassificationPerformanceEvaluator -a " + alphaOption.getValue() + ")"); return learningCurve; @@ -194,23 +188,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -261,12 +238,6 @@ else if((this.initialWindowSizeOption.getValue() - instancesProcessed) < this.de testInstance = ((Instance) currentInst.getData()).copy(); testInst = new InstanceExample(testInstance); - // Output prediction - if (outputPredictionFile != null) { - int trueClass = (int) ((Instance) currentInst.getData()).classValue(); - outputPredictionResultStream.println(Utils.maxIndex(prediction) + "," + ( - ((Instance) testInst.getData()).classIsMissing() == true ? " ? " : trueClass)); - } evaluator.addResult(testInst, prediction); if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 @@ -328,8 +299,10 @@ else if((this.initialWindowSizeOption.getValue() - instancesProcessed) < this.de if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java index 58d79d91f..2f7cee1a6 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialDelayedCV.java @@ -282,6 +282,12 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } + try { + for (LearningPerformanceEvaluator evaluator : evaluators) + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java index 8a5a198a9..2808c59dc 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiLabel.java @@ -100,9 +100,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -171,23 +168,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -200,27 +180,15 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { Example testInst = (Example) trainInst; //.copy(); - //testInst.setClassMissing(); - //double[] prediction = learner.getVotesForInstance(testInst); - if ( instancesProcessed==0){ learner.trainOnInstance(trainInst); instancesProcessed++; continue; } - //evaluator.addClassificationAttempt(trueClass, prediction, testInst.weight()); - + Prediction prediction = learner.getPredictionForInstance(testInst); - - // Output prediction - if (outputPredictionFile != null) { - double trueClass = ((Instance) trainInst.getData()).classValue(); - outputPredictionResultStream.println(prediction + "," + trueClass); - } - evaluator.addResult(testInst, prediction); - learner.trainOnInstance(trainInst); instancesProcessed++; @@ -286,8 +254,10 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java index 52b24b4dd..02af4bb08 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialMultiTargetSemiSuper.java @@ -178,22 +178,7 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { throw new RuntimeException("Unable to open immediate result file: " + dumpFile, ex); } } - - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream(new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream(new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException("Unable to open prediction result file: " + outputPredictionFile, ex); - } - } - + boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -402,10 +387,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } - return learningCurve; } } diff --git a/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java b/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java index 26a3c844c..d60a2534f 100644 --- a/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java +++ b/moa/src/main/java/moa/tasks/EvaluatePrequentialRegression.java @@ -99,9 +99,6 @@ public String getPurposeString() { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true); - public FileOption outputPredictionFileOption = new FileOption("outputPredictionFile", 'o', - "File to append output predictions to.", null, "pred", true); - //New for prequential method DEPRECATED public IntOption widthOption = new IntOption("width", 'w', "Size of Window", 1000); @@ -170,23 +167,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { "Unable to open immediate result file: " + dumpFile, ex); } } - //File for output predictions - File outputPredictionFile = this.outputPredictionFileOption.getFile(); - PrintStream outputPredictionResultStream = null; - if (outputPredictionFile != null) { - try { - if (outputPredictionFile.exists()) { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile, true), true); - } else { - outputPredictionResultStream = new PrintStream( - new FileOutputStream(outputPredictionFile), true); - } - } catch (Exception ex) { - throw new RuntimeException( - "Unable to open prediction result file: " + outputPredictionFile, ex); - } - } boolean firstDump = true; boolean preciseCPUTiming = TimingUtils.enablePreciseTiming(); long evaluateStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread(); @@ -197,18 +177,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { && ((maxSeconds < 0) || (secondsElapsed < maxSeconds))) { Example trainInst = stream.nextInstance(); Example testInst = (Example) trainInst; //.copy(); - //testInst.setClassMissing(); - //double[] prediction = learner.getVotesForInstance(testInst); Prediction prediction = learner.getPredictionForInstance(testInst); - // Output prediction - if (outputPredictionFile != null) { - double trueClass = ((Instance) trainInst.getData()).classValue(); - outputPredictionResultStream.println(prediction + "," + trueClass); - } - - //evaluator.addClassificationAttempt(trueClass, prediction, testInst.weight()); + evaluator.addResult(testInst, prediction); learner.trainOnInstance(trainInst); + instancesProcessed++; if (instancesProcessed % this.sampleFrequencyOption.getValue() == 0 || stream.hasMoreInstances() == false) { @@ -269,8 +242,10 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - if (outputPredictionResultStream != null) { - outputPredictionResultStream.close(); + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); } return learningCurve; } diff --git a/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java b/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java index 5aa14e8ae..a9b1a6b24 100644 --- a/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java +++ b/moa/src/main/java/moa/tasks/WriteConfigurationToJupyterNotebook.java @@ -104,7 +104,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { learnerString = ((EvaluatePrequential) currentTask).learnerOption.getValueAsCLIString().replace('\\', '/'); evaluatorString = ((EvaluatePrequential) currentTask).evaluatorOption.getValueAsCLIString().replace('\\', '/'); dumpFile = ((EvaluatePrequential) currentTask).dumpFileOption.getFile(); - outputPredictionFile = ((EvaluatePrequential) currentTask).outputPredictionFileOption.getFile(); sampleFrequency = ((EvaluatePrequential) currentTask).sampleFrequencyOption.getValue(); instanceLimit = ((EvaluatePrequential) currentTask).instanceLimitOption.getValue(); @@ -154,7 +153,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { learnerString = ((EvaluatePrequentialDelayed) currentTask).learnerOption.getValueAsCLIString().replace('\\', '/'); evaluatorString = ((EvaluatePrequentialDelayed) currentTask).evaluatorOption.getValueAsCLIString().replace('\\', '/'); dumpFile = ((EvaluatePrequentialDelayed) currentTask).dumpFileOption.getFile(); - outputPredictionFile = ((EvaluatePrequentialDelayed) currentTask).outputPredictionFileOption.getFile(); sampleFrequency = ((EvaluatePrequentialDelayed) currentTask).sampleFrequencyOption.getValue(); instanceLimit = ((EvaluatePrequentialDelayed) currentTask).instanceLimitOption.getValue(); trainOnInitialWindow = ((EvaluatePrequentialDelayed) currentTask).trainOnInitialWindowOption.isSet(); @@ -211,7 +209,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { sampleFrequency = ((EvaluatePrequentialRegression) currentTask).sampleFrequencyOption.getValue(); instanceLimit = ((EvaluatePrequentialRegression) currentTask).instanceLimitOption.getValue(); dumpFile = ((EvaluatePrequentialRegression) currentTask).dumpFileOption.getFile(); - outputPredictionFile = ((EvaluatePrequentialRegression) currentTask).outputPredictionFileOption.getFile(); //New for prequential methods if (getPreparedClassOption(((EvaluatePrequentialRegression) currentTask).evaluatorOption) instanceof WindowClassificationPerformanceEvaluator) { @@ -373,11 +370,6 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if ((currentTask instanceof EvaluatePrequential) || (currentTask instanceof EvaluatePrequentialDelayed) || (currentTask instanceof EvaluatePrequentialRegression)) { - //File for output predictions - if (outputPredictionFile != null) { - nb.getLastCell().addSource("File outputPredictionFile = new File(\"" + - outputPredictionFile.getAbsolutePath().replace('\\', '/')+ "\");"); - } else nb.getLastCell().addSource("File outputPredictionFile = null;"); nb.getLastCell().addSource("PrintStream outputPredictionResultStream = null;") .addSource("if (outputPredictionFile != null) {") @@ -793,7 +785,6 @@ else if (((MainTask) this.task).outputFileOption != null) throw new RuntimeException( "Failed implementing task ", ex); } - return result; } diff --git a/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java b/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java index a2813a1b6..24006d09c 100644 --- a/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java +++ b/moa/src/main/java/moa/tasks/meta/ALPrequentialEvaluationTask.java @@ -270,7 +270,11 @@ protected Object doMainTask(TaskMonitor monitor, ObjectRepository repository) { if (immediateResultStream != null) { immediateResultStream.close(); } - + try { + evaluator.close(); + } catch (Exception ex) { + throw new RuntimeException("Exception closing evaluator", ex); + } return new PreviewCollectionLearningCurveWrapper(learningCurve, this.getClass()); }