forked from nus-cs2103-AY1718S2/addressbook-level4
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request nus-cs2103-AY1718S2#74 from CS2103JAN2018-T15-B2/v…
…1.4YuxiangSg collatation
- Loading branch information
Showing
34 changed files
with
1,501 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# SoilChang | ||
###### \java\seedu\address\logic\GradientDescent.java | ||
``` java | ||
public class GradientDescent { | ||
public static final String MESSAGE_PREDICTION_SUCCESS = "Prediction success"; | ||
public static final String MESSAGE_PREDICTION_FAIL = "Prediction failed"; | ||
public static final String MESSAGE_PREDICTION_DIVERGENT = "Prediction Solution is not convergent"; | ||
private static GradientDescent instance; | ||
/** | ||
* The weights we are calculating | ||
* For more fields, please add in the default value below | ||
* [ x1= income, x2 = age ] | ||
*/ | ||
private static ArrayList<Double> weights = new ArrayList<>(); | ||
|
||
/** | ||
* Some constants to normalize income and age so that the value don't differ so much | ||
*/ | ||
private static ArrayList<Double> normalizationConstant = | ||
new ArrayList<>(Arrays.asList(1000.0)); | ||
|
||
|
||
private final Logger logger = LogsCenter.getLogger(LogicManager.class); | ||
private PredictionModel model; | ||
/** | ||
* The learning rate | ||
*/ | ||
private final Double learningRate = 0.000001; | ||
|
||
/** | ||
* The amount of epoch of looping through training data | ||
*/ | ||
private final Integer epoch = 300000; | ||
|
||
|
||
/** | ||
* Singleton | ||
* | ||
* @param model | ||
*/ | ||
private GradientDescent(PredictionModel model) { | ||
this.model = model; | ||
} | ||
|
||
``` | ||
###### \java\seedu\address\logic\GradientDescent.java | ||
``` java | ||
public static GradientDescent getInstance(PredictionModel model) { | ||
if (instance == null) { | ||
instance = new GradientDescent(model); | ||
instance.resetWeights(); | ||
} | ||
return instance; | ||
} | ||
|
||
|
||
``` | ||
###### \java\seedu\address\logic\GradientDescent.java | ||
``` java | ||
|
||
/** | ||
* Drives the whole algorithm to solve the problem | ||
*/ | ||
public CommandResult solve() throws CommandException { | ||
/** | ||
* row -> each entry of data | ||
* col -> [ income ] | ||
*/ | ||
ArrayList<ArrayList<Double>> matrix = new ArrayList<>(new ArrayList<>()); | ||
|
||
/** | ||
* The actual outcome, this "ActualSpending" | ||
*/ | ||
ArrayList<Double> targets = new ArrayList<>(); | ||
|
||
|
||
//extract values | ||
this.model.preparePredictionData(matrix, targets, normalizationConstant); | ||
|
||
|
||
//solve | ||
descent(matrix, targets); | ||
|
||
if (this.hasNaN(this.weights)) { | ||
return new CommandResult(String.format(MESSAGE_PREDICTION_DIVERGENT)); | ||
} | ||
//update results | ||
try { | ||
this.model.updatePredictionResult(this.getWeights()); | ||
return new CommandResult(String.format(MESSAGE_PREDICTION_SUCCESS)); | ||
} catch (CommandException e) { | ||
return new CommandResult(String.format(MESSAGE_PREDICTION_FAIL)); | ||
} | ||
} | ||
``` | ||
###### \java\seedu\address\logic\GradientDescent.java | ||
``` java | ||
|
||
/** | ||
* Perform stochastic gradient descent on the input data | ||
*/ | ||
private void descent(ArrayList<ArrayList<Double>> matrix, ArrayList<Double> targets) { | ||
for (int itt = 0; itt < epoch; itt++) { | ||
|
||
//check data validity | ||
if (this.hasNaN(this.weights)) { | ||
this.logger.warning("The solution is not convergent"); | ||
break; | ||
} | ||
|
||
// fixed amount of training iteration | ||
for (int r = 0; r < matrix.size(); r++) { //going through each training data | ||
|
||
ArrayList<Double> row = matrix.get(r); | ||
Double outcome = predict(row); | ||
Double error = targets.get(r) - outcome; | ||
for (int i = 0; i < row.size(); i++) { | ||
Double deltaW = this.learningRate * error * row.get(i); | ||
this.weights.set(i, this.weights.get(i) + deltaW); | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Check if the ArrayList contains NaN | ||
* | ||
* @param weights | ||
* @return | ||
*/ | ||
private boolean hasNaN(ArrayList<Double> weights) { | ||
for (Double i : weights) { | ||
if (i.isNaN()) { | ||
return true; | ||
} | ||
} | ||
|
||
return false; | ||
} | ||
|
||
``` | ||
###### \java\seedu\address\logic\GradientDescent.java | ||
``` java | ||
|
||
/** | ||
* calculate values based on the current weights | ||
*/ | ||
private Double predict(ArrayList<Double> row) { | ||
Double sum = 0.0; | ||
//sum income | ||
for (int i = 0; i < row.size(); i++) { | ||
sum += this.weights.get(i) * row.get(i); | ||
} | ||
|
||
return sum; | ||
} | ||
|
||
|
||
/** | ||
* Weight getter | ||
* | ||
* @return | ||
*/ | ||
public static ArrayList<Double> getWeights() { | ||
//restore the un-normalized value | ||
ArrayList<Double> trueWeights = new ArrayList<>(); | ||
for (int j = 0; j < GradientDescent.weights.size(); j++) { | ||
trueWeights.add(GradientDescent.weights.get(j) / normalizationConstant.get(j)); | ||
} | ||
return trueWeights; | ||
} | ||
|
||
/** | ||
* reset the weights to initial values of zeros | ||
*/ | ||
public static void resetWeights() { | ||
GradientDescent.weights = new ArrayList<>(Arrays.asList(0.0)); | ||
} | ||
} | ||
``` | ||
###### \java\seedu\address\model\PredictionModel.java | ||
``` java | ||
public interface PredictionModel extends Model { | ||
void preparePredictionData(ArrayList<ArrayList<Double>> matrix, ArrayList<Double> targets, | ||
ArrayList<Double> normalizationConstant); | ||
|
||
void updatePredictionResult(ArrayList<Double> weights) | ||
throws CommandException; | ||
} | ||
``` |
Oops, something went wrong.