diff --git a/src/main/java/com/guokr/simbase/score/CosineSquareSimilarity.java b/src/main/java/com/guokr/simbase/score/CosineSquareSimilarity.java index 8040c3b..7f2dc9d 100644 --- a/src/main/java/com/guokr/simbase/score/CosineSquareSimilarity.java +++ b/src/main/java/com/guokr/simbase/score/CosineSquareSimilarity.java @@ -69,7 +69,9 @@ public float score(String srcVKey, int srcId, int[] source, String tgtVKey, int int len2 = target.length; int idx1 = 0, idx2 = 0; while (idx1 < len1 && idx2 < len2) { - if (source[idx1] == target[idx2]) { + if (source[idx1] < 0 || target[idx2] < 0) { + break; + } else if (source[idx1] == target[idx2]) { scoring += source[idx1 + 1] * target[idx2 + 1]; idx1 += 2; idx2 += 2; diff --git a/src/main/java/com/guokr/simbase/score/JensenShannonDivergence.java b/src/main/java/com/guokr/simbase/score/JensenShannonDivergence.java index 4605772..c856b5c 100644 --- a/src/main/java/com/guokr/simbase/score/JensenShannonDivergence.java +++ b/src/main/java/com/guokr/simbase/score/JensenShannonDivergence.java @@ -81,7 +81,9 @@ public float score(String srcVKey, int srcId, int[] source, String tgtVKey, int int len2 = target.length; int idx1 = 0, idx2 = 0; while (idx1 < len1 && idx2 < len2) { - if (source[idx1] == target[idx2]) { + if (source[idx1] < 0 || target[idx2] < 0) { + break; + } else if (source[idx1] == target[idx2]) { float p = source[idx1 + 1]; float q = target[idx2 + 1]; float m = (p + q) / 2; diff --git a/src/main/java/com/guokr/simbase/store/Basis.java b/src/main/java/com/guokr/simbase/store/Basis.java index eca0c87..2bfef00 100644 --- a/src/main/java/com/guokr/simbase/store/Basis.java +++ b/src/main/java/com/guokr/simbase/store/Basis.java @@ -88,10 +88,8 @@ public void revise(String[] base) { } } - public static float[] densify(int size, int sparseFactor, int[] pairs) { + public static void densify(int size, int sparseFactor, int[] pairs, float[] result) { int length = pairs.length; - float[] result = new float[size]; - int index = 0, cursor = 0; float sum = 0f; while (cursor < size) { @@ -115,10 +113,35 @@ public static float[] densify(int size, int sparseFactor, int[] pairs) { result[i] = ((float) Math.round(1000f / size)) / 1000; } } + } + public static float[] densify(int size, int sparseFactor, int[] pairs) { + float[] result = new float[size]; + densify(size, sparseFactor, pairs, result); return result; } + public static void sparsify(int sparseFactor, float[] distr, int[] result) { + int cursor = 0, sum = 0, idx = 0, length = result.length; + for (float ftmp : distr) { + int itmp = Math.round(ftmp * sparseFactor); + if (itmp > 0) { + result[idx] = cursor; + result[idx + 1] = itmp; + sum += itmp; + } + cursor++; + idx = idx + 2; + } + for (int i = idx; i < length; i++) { + result[i++] = -1; + } + for (int i = 0; i < idx;) { + result[i + 1] = (int) Math.round(((float) result[i + 1]) / sum * sparseFactor); + i += 2; + } + } + public static int[] sparsify(int sparseFactor, float[] distr) { TIntArrayList resultList = new TIntArrayList(); diff --git a/src/main/java/com/guokr/simbase/store/DenseVectorSet.java b/src/main/java/com/guokr/simbase/store/DenseVectorSet.java index f70e661..2c7380a 100644 --- a/src/main/java/com/guokr/simbase/store/DenseVectorSet.java +++ b/src/main/java/com/guokr/simbase/store/DenseVectorSet.java @@ -122,26 +122,30 @@ public void remove(int vecid) { } } + protected void get(int vecid, float[] result) { + float ftmp = 0; + int cursor = 0; + int dim = dimns.get(vecid); + int start = indexer.get(vecid); + while (cursor < result.length) { + if (cursor < dim) { + ftmp = probs.get(start + cursor); + if (ftmp >= 0 && ftmp <= 1) { + result[cursor] = ftmp; + } + } else { + result[cursor] = 0; + } + cursor++; + } + } + @Override public float[] get(int vecid) { float[] result; if (indexer.containsKey(vecid)) { result = new float[this.base.size()]; - float ftmp = 0; - int cursor = 0; - int dim = dimns.get(vecid); - int start = indexer.get(vecid); - while (cursor < result.length) { - if (cursor < dim) { - ftmp = probs.get(start + cursor); - if (ftmp >= 0 && ftmp <= 1) { - result[cursor] = ftmp; - } - } else { - result[cursor] = 0; - } - cursor++; - } + get(vecid, result); } else { result = new float[0]; } @@ -216,9 +220,17 @@ public void accumulate(int vecid, float[] vector) { } } + protected void _get(int vecid, float[] input, int[] result) { + get(vecid, input); + Basis.sparsify(sparseFactor, input, result); + } + @Override public int[] _get(int vecid) { - return Basis.sparsify(sparseFactor, get(vecid)); + int[] result = new int[this.base.size()]; + float[] input = new float[this.base.size()]; + _get(vecid, input, result); + return result; } @Override @@ -245,11 +257,12 @@ public void addListener(VectorSetListener listener) { public void rescore(String key, int vecid, float[] vector, Recommendation rec) { rec.create(vecid); TIntIntIterator iter = indexer.iterator(); + float[] target = new float[this.base.size()]; if (this == rec.source) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - float[] target = get(tgtId); + get(tgtId, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); rec.add(tgtId, vecid, score); @@ -259,7 +272,7 @@ public void rescore(String key, int vecid, float[] vector, Recommendation rec) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - float[] target = get(tgtId); + get(tgtId, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); } @@ -270,11 +283,13 @@ public void rescore(String key, int vecid, float[] vector, Recommendation rec) { public void rescore(String key, int vecid, int[] vector, Recommendation rec) { rec.create(vecid); TIntIntIterator iter = indexer.iterator(); + float[] input = new float[this.base.size()]; + int[] target = new int[this.base.size() * 2]; if (this == rec.source) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - int[] target = _get(tgtId); + _get(tgtId, input, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); rec.add(tgtId, vecid, score); @@ -284,7 +299,7 @@ public void rescore(String key, int vecid, int[] vector, Recommendation rec) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - int[] target = _get(tgtId); + _get(tgtId, input, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); } diff --git a/src/main/java/com/guokr/simbase/store/SparseVectorSet.java b/src/main/java/com/guokr/simbase/store/SparseVectorSet.java index 6f826f0..4e89b48 100644 --- a/src/main/java/com/guokr/simbase/store/SparseVectorSet.java +++ b/src/main/java/com/guokr/simbase/store/SparseVectorSet.java @@ -134,9 +134,17 @@ public void remove(int vecid) { } } + public float[] get(int vecid, int[] input, float[] result) { + _get(vecid, input); + Basis.densify(base.size(), sparseFactor, input, result); + return result; + } + @Override public float[] get(int vecid) { - return Basis.densify(base.size(), sparseFactor, _get(vecid)); + float[] result = new float[base.size()]; + Basis.densify(base.size(), sparseFactor, _get(vecid), result); + return result; } @Override @@ -154,23 +162,24 @@ public void accumulate(int vecid, float[] vector) { _accumulate(vecid, Basis.sparsify(sparseFactor, vector)); } - @Override - public int[] _get(int vecid) { - TIntArrayList resultList = new TIntArrayList(base.size()); - if (indexer.containsKey(vecid)) { - int cursor = indexer.get(vecid); - while (true) { - int pos = (int) probs.get(cursor++); - if (pos < 0) { - break; - } - int val = Math.round(probs.get(cursor++)); - resultList.add(pos); - resultList.add(val); + protected void _get(int vecid, int[] result) { + int cursor = indexer.get(vecid), i = 0; + while (true) { + int pos = (int) probs.get(cursor++); + if (pos < 0) { + break; } + int val = Math.round(probs.get(cursor++)); + result[i++] = pos; + result[i++] = val; } - int[] result = new int[resultList.size()]; - return resultList.toArray(result); + } + + @Override + public int[] _get(int vecid) { + int[] result = new int[base.size() * 2]; + _get(vecid, result); + return result; } @Override @@ -277,11 +286,13 @@ public void addListener(VectorSetListener listener) { public void rescore(String key, int vecid, float[] vector, Recommendation rec) { rec.create(vecid); TIntIntIterator iter = indexer.iterator(); + int[] input = new int[this.base.size() * 2]; + float[] target = new float[this.base.size()]; if (this == rec.source) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - float[] target = get(tgtId); + get(tgtId, input, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); rec.add(tgtId, vecid, score); @@ -291,7 +302,7 @@ public void rescore(String key, int vecid, float[] vector, Recommendation rec) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - float[] target = get(tgtId); + get(tgtId, input, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); } @@ -302,11 +313,12 @@ public void rescore(String key, int vecid, float[] vector, Recommendation rec) { public void rescore(String key, int vecid, int[] vector, Recommendation rec) { rec.create(vecid); TIntIntIterator iter = indexer.iterator(); + int[] target = new int[this.base.size() * 2]; if (this == rec.source) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - int[] target = _get(tgtId); + _get(tgtId, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); rec.add(tgtId, vecid, score); @@ -316,7 +328,7 @@ public void rescore(String key, int vecid, int[] vector, Recommendation rec) { while (iter.hasNext()) { iter.advance(); int tgtId = iter.key(); - int[] target = _get(tgtId); + _get(tgtId, target); float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, target); rec.add(vecid, tgtId, score); }