-
Notifications
You must be signed in to change notification settings - Fork 1
Word2Vec (skip‐gram)
MoonDooo edited this page Nov 14, 2023
·
12 revisions
전체 소스코드
/**
 * Skip-gram Word2Vec trainer backed by two dense weight matrices.
 *
 * <p>Both matrices are laid out as {@code [DIMENSION][vocabulary size]}. Column {@code w} of
 * {@link #inputHiddenWeight} is the hidden-layer activation (the learned vector) for the word
 * whose dictionary index is {@code w}. Word indices and the training corpus are obtained from
 * the injected {@code Morpheme}.
 *
 * <p>Training is plain stochastic gradient descent over (centre word, context word) pairs with a
 * full softmax output layer. No synchronisation is performed in this class.
 */
public class Word2Vec2VImpl implements Word2Vec{
    /** Number of hidden units, i.e. the length of every learned word vector. */
    protected final int DIMENSION = 100;
    /** Step size used for every gradient-descent update. */
    private final double learnRate = 0.005;
    /** Input-to-hidden weights, indexed {@code [hidden unit][word index]}. */
    protected double [][] inputHiddenWeight;
    /** Hidden-to-output weights, indexed {@code [hidden unit][word index]}. */
    private double [][] hiddenOutputWeight;
    /** Source of the word-index map and of the noun sequence used for training. */
    protected final Morpheme morpheme;
    /** Number of neighbours taken on each side of the centre word. */
    private final int windowSize = 2;
    /** Number of full passes over the noun list. */
    private final int epoch = 20;
    /** Vocabulary size; set by {@link #initData()}. */
    private int size;

    /**
     * Creates a trainer that reads its vocabulary and corpus from {@code morpheme}.
     *
     * @param morpheme provider of the word-index map and the noun list
     */
    @Autowired
    public Word2Vec2VImpl(Morpheme morpheme){
        this.morpheme = morpheme;
    }

    /**
     * Reads the vocabulary size, randomly initialises both weight matrices and trains them for
     * {@code epoch} passes over the noun list. The first entry of each matrix and the epoch
     * number are printed before every pass so that diverging weights can be spotted.
     */
    @Override
    public void initData(){
        size = morpheme.getWordIdxMap().size();
        List<String> nounList = morpheme.findAllNounByDestination();
        initArray();
        for (int i = 0; i < epoch; i++) {
            System.out.println("w1 > " + hiddenOutputWeight[0][0]);
            System.out.println("w2 > " + inputHiddenWeight[0][0]);
            System.out.println("epoch > " + i);
            learningWeight(nounList);
        }
    }

    /**
     * Runs one training pass: every word in {@code nounList} is used once as the centre word and
     * is paired with each neighbour at most {@code windowSize} positions away.
     *
     * @param nounList corpus as an ordered sequence of words
     */
    private void learningWeight(List<String> nounList) {
        final int wordCount = nounList.size();
        // Visit every position, including the last windowSize words; the window is clamped to
        // the list bounds instead of shortening the loop.
        for (int center = 0; center < wordCount; center++) {
            int startWindow = Math.max(0, center - windowSize);
            int endWindow = Math.min(wordCount, center + windowSize + 1);
            int oneHotInput = morpheme.getIdx(nounList.get(center));
            for (int j = startWindow; j < endWindow; j++) {
                if (j == center) {
                    continue;
                }
                int oneHotOutput = morpheme.getIdx(nounList.get(j));
                // Recomputed per pair on purpose: learn() changes the weights each time.
                double[] result = forwardPassWithSoftmax(oneHotInput);
                learn(result, oneHotInput, oneHotOutput);
            }
        }
    }

    /**
     * Applies one gradient-descent step for a single (centre, context) pair.
     *
     * @param result       softmax output produced for {@code oneHotInput}
     * @param oneHotInput  dictionary index of the centre word
     * @param oneHotOutput dictionary index of the context word to be predicted
     */
    private void learn(double[] result, int oneHotInput, int oneHotOutput) {
        // Output error: prediction minus the one-hot target.
        double[] delta = new double[size];
        for (int i = 0; i < size; i++) {
            delta[i] = (i == oneHotOutput) ? result[i] - 1 : result[i];
        }
        for (int i = 0; i < DIMENSION; i++) {
            double hidden = inputHiddenWeight[i][oneHotInput];
            double hiddenDelta = 0;
            for (int j = 0; j < size; j++) {
                // Accumulate the hidden-layer gradient from the weight as it was during the
                // forward pass, then update that weight.
                hiddenDelta += hiddenOutputWeight[i][j] * delta[j];
                hiddenOutputWeight[i][j] -= learnRate * hidden * delta[j];
            }
            inputHiddenWeight[i][oneHotInput] -= learnRate * hiddenDelta;
        }
    }

    /**
     * Computes the raw output scores (logits) for a one-hot input word.
     *
     * @param oneHotInput dictionary index of the input word
     * @return array of length {@code size} holding one score per vocabulary word
     */
    private double[] forwardPass(int oneHotInput) {
        double[] result = new double[size];
        for (int i = 0; i < DIMENSION; i++) {
            // With a one-hot input the hidden activation is just one column of the matrix.
            double hidden = inputHiddenWeight[i][oneHotInput];
            for (int j = 0; j < size; j++) {
                result[j] += hidden * hiddenOutputWeight[i][j];
            }
        }
        return result;
    }

    /** Allocates both weight matrices as {@code [DIMENSION][size]} and fills them randomly. */
    private void initArray() {
        inputHiddenWeight = new double[DIMENSION][size];
        hiddenOutputWeight = new double[DIMENSION][size];
        initArrayToRandom(inputHiddenWeight);
        initArrayToRandom(hiddenOutputWeight);
    }

    /**
     * Overwrites every entry of {@code weightArray} with a random value in {@code [-0.5, 0.5)}.
     *
     * @param weightArray matrix to fill in place
     */
    private void initArrayToRandom(double[][] weightArray) {
        Random random = new Random();
        for (double[] row : weightArray) {
            for (int j = 0; j < row.length; j++) {
                row[j] = -0.5 + random.nextDouble();
            }
        }
    }

    /**
     * Runs the forward pass for {@code oneHotInput} and normalises the scores with softmax.
     *
     * @param oneHotInput dictionary index of the input word
     * @return probability for every vocabulary word, in dictionary-index order
     * @throws IllegalStateException if {@link #initData()} has not been called yet
     */
    public double[] forwardPassWithSoftmax(int oneHotInput){
        if (inputHiddenWeight == null || hiddenOutputWeight == null) {
            throw new IllegalStateException("initData() must be called before forwardPassWithSoftmax()");
        }
        double[] result = forwardPass(oneHotInput);
        double[] softmax = new double[result.length];
        // Start below every possible score so the true maximum is found even when all scores
        // are negative; subtracting it keeps exp() from overflowing or underflowing to zero.
        double max = Double.NEGATIVE_INFINITY;
        for (double v : result) {
            if (max < v) {
                max = v;
            }
        }
        double tmp = 0;
        for (double v : result) {
            tmp += Math.exp(v - max);
        }
        for (int i = 0; i < result.length; i++) {
            softmax[i] = Math.exp(result[i] - max) / tmp;
        }
        return softmax;
    }
}
이미지
Word2Vec은 크게 다음과 같이 나뉜다.
- CBOW는 주변 단어로 중심 단어를 예측한다.
- Skip gram은 중심 단어를 통해 주변 단어를 예측한다.
Word2Vec은 위와 같은 과정을 통해 결과적으로 단어들을 의미를 반영한 벡터로 표현할 수 있다. 이 프로젝트에서는 의미 표현 측면에서 skip-gram이 더 효과가 좋다고 하여 skip-gram을 이용하였다. 출처(p. 7)
/**
 * Stores the vocabulary size, initialises the weight matrices and runs {@code epoch} training
 * passes over the noun list. After each pass the first entry of each weight matrix and the
 * pass number are printed.
 */
@Override
public void initData(){
    size = morpheme.getWordIdxMap().size();
    final List<String> nouns = morpheme.findAllNounByDestination();
    initArray();
    int pass = 0;
    while (pass < epoch) {
        learningWeight(nouns);
        System.out.println("w1 > " + hiddenOutputWeight[0][0]);
        System.out.println("w2 > " + inputHiddenWeight[0][0]);
        System.out.println("epoch > " + pass);
        pass++;
    }
}
출력문은 발산 확인용으로 넣었다.
morpheme은 komoran 형태소 분석 라이브러리를 사용해 데이터로부터 명사를 모은 사전을 만드는 클래스이다.
initArray()는 [히든 노드 개수][사전 크기] 크기의 배열 2개를 작은 랜덤값으로 초기화하고, 그 뒤 initData()의 반복문이 epoch 횟수만큼 learningWeight()를 호출하여 연결강도를 학습시킨다.
/**
 * Runs one training pass: every word in {@code nounList} is used once as the centre word and
 * is paired with each neighbour at most {@code windowSize} positions away.
 *
 * @param nounList corpus as an ordered sequence of words
 */
private void learningWeight(List<String> nounList) {
    final int wordCount = nounList.size();
    // Visit every position, including the last windowSize words; the window is clamped to the
    // list bounds, so the centre index is simply the loop index for any windowSize.
    for (int center = 0; center < wordCount; center++) {
        int startWindow = Math.max(0, center - windowSize);
        int endWindow = Math.min(wordCount, center + windowSize + 1);
        int oneHotInput = morpheme.getIdx(nounList.get(center));
        for (int j = startWindow; j < endWindow; j++) {
            if (j == center) {
                continue;
            }
            int oneHotOutput = morpheme.getIdx(nounList.get(j));
            // Recomputed per pair on purpose: learn() changes the weights each time.
            double[] result = forwardPassWithSoftmax(oneHotInput);
            learn(result, oneHotInput, oneHotOutput);
        }
    }
}
Spring 임경완 |
---|
@ MoonDooo |