Skip to content

Word2Vec (skip‐gram)

MoonDooo edited this page Nov 14, 2023 · 12 revisions
전체 소스코드
/**
 * Skip-gram Word2Vec trained with a full softmax over the vocabulary exposed by
 * {@code morpheme.getWordIdxMap()}.
 *
 * <p>Both weight matrices are stored as {@code [DIMENSION][vocabularySize]}; column {@code w} of
 * {@link #inputHiddenWeight} is the embedding learned for the word with index {@code w}.
 * {@link #initData()} must be called before any forward pass.
 */
public class Word2Vec2VImpl implements Word2Vec{
    /** Number of hidden units, i.e. the length of each learned word vector. */
    protected final int DIMENSION = 100;
    /** Step size of the plain gradient-descent update. */
    private static final double LEARN_RATE = 0.005;
    /** Number of context words considered on each side of the center word. */
    private static final int WINDOW_SIZE = 2;
    /** Number of full passes over the noun list. */
    private static final int EPOCH = 20;

    protected double[][] inputHiddenWeight;
    private double[][] hiddenOutputWeight;
    protected final Morpheme morpheme;
    /** Vocabulary size captured by {@link #initData()}; 0 until then. */
    private int size;

    @Autowired
    public Word2Vec2VImpl(Morpheme morpheme){
        this.morpheme = morpheme;
    }

    /**
     * Captures the vocabulary size, randomly initialises both weight matrices and trains them
     * for {@code EPOCH} passes over the noun list returned by the morpheme component.
     * After every epoch the first entry of each matrix is printed so divergence can be spotted.
     */
    @Override
    public void initData(){
        size = morpheme.getWordIdxMap().size();
        List<String> nounList = morpheme.findAllNounByDestination();
        initArray();
        if (size == 0) {
            // Nothing to learn, and the progress output below would index [0][0] of empty rows.
            return;
        }
        for (int i = 0; i < EPOCH; i++) {
            learningWeight(nounList);
            System.out.println("w1 > " + hiddenOutputWeight[0][0]);
            System.out.println("w2 > " + inputHiddenWeight[0][0]);
            System.out.println("epoch > " + i);
        }
    }

    /**
     * Runs one training pass: every position in {@code nounList} is used as the center word and
     * each neighbour within {@code WINDOW_SIZE} positions is used as a prediction target.
     *
     * @param nounList the noun sequence to train on, in corpus order
     */
    private void learningWeight(List<String> nounList) {
        int total = nounList.size();
        // Iterate over ALL positions: the window is clamped at both ends, so the trailing
        // words can serve as center words just like the leading ones.
        for (int center = 0; center < total; center++) {
            int startWindow = Math.max(0, center - WINDOW_SIZE);
            int endWindow = Math.min(total, center + WINDOW_SIZE + 1);
            int oneHotInput = morpheme.getIdx(nounList.get(center));
            for (int j = startWindow; j < endWindow; j++) {
                if (j == center) {
                    continue;
                }
                int oneHotOutput = morpheme.getIdx(nounList.get(j));
                // Recomputed per context word because learn() changes the weights in between.
                double[] result = forwardPassWithSoftmax(oneHotInput);
                learn(result, oneHotInput, oneHotOutput);
            }
        }
    }

    /**
     * Applies one gradient-descent step for a single (center, context) pair.
     *
     * @param result       softmax output produced for {@code oneHotInput}
     * @param oneHotInput  vocabulary index of the center word
     * @param oneHotOutput vocabulary index of the context word that should have been predicted
     */
    private void learn(double[] result, int oneHotInput, int oneHotOutput) {
        // Output error: softmax probability minus the one-hot target.
        double[] delta = new double[size];
        for (int i = 0; i < size; i++) {
            delta[i] = (i == oneHotOutput) ? result[i] - 1 : result[i];
        }
        for (int i = 0; i < DIMENSION; i++) {
            double hidden = inputHiddenWeight[i][oneHotInput];
            double hiddenDelta = 0;
            for (int j = 0; j < size; j++) {
                // Accumulate the hidden-layer gradient with the weight used in the forward
                // pass, i.e. BEFORE this step modifies it.
                hiddenDelta += hiddenOutputWeight[i][j] * delta[j];
                hiddenOutputWeight[i][j] -= LEARN_RATE * hidden * delta[j];
            }
            inputHiddenWeight[i][oneHotInput] -= LEARN_RATE * hiddenDelta;
        }
    }

    /**
     * Computes the raw output scores for a one-hot input: the input word's column of
     * {@link #inputHiddenWeight} multiplied through {@code hiddenOutputWeight}.
     *
     * @param oneHotInput vocabulary index of the input word
     * @return one score per vocabulary entry
     * @throws IllegalStateException if the weights have not been initialised yet
     */
    private double[] forwardPass(int oneHotInput) {
        if (inputHiddenWeight == null || hiddenOutputWeight == null) {
            throw new IllegalStateException("initData() must be called before a forward pass");
        }
        double[] result = new double[size];
        for (int i = 0; i < DIMENSION; i++) {
            double hidden = inputHiddenWeight[i][oneHotInput];
            for (int j = 0; j < size; j++) {
                result[j] += hidden * hiddenOutputWeight[i][j];
            }
        }
        return result;
    }

    /** Allocates both {@code [DIMENSION][size]} weight matrices and fills them with random values. */
    private void initArray() {
        inputHiddenWeight = new double[DIMENSION][size];
        hiddenOutputWeight = new double[DIMENSION][size];
        initArrayToRandom(inputHiddenWeight);
        initArrayToRandom(hiddenOutputWeight);
    }

    /**
     * Fills every cell of {@code weightArray} with a pseudo-random value in {@code [-0.5, 0.5)}.
     *
     * @param weightArray matrix to overwrite in place
     */
    private void initArrayToRandom(double[][] weightArray) {
        Random random = new Random();
        for (double[] row : weightArray) {
            for (int j = 0; j < row.length; j++) {
                row[j] = -0.5 + random.nextDouble();
            }
        }
    }

    /**
     * Runs the forward pass for {@code oneHotInput} and normalises the scores with a softmax.
     *
     * @param oneHotInput vocabulary index of the input word
     * @return a probability per vocabulary entry, summing to 1 (empty if the vocabulary is empty)
     * @throws IllegalStateException if {@link #initData()} has not been called yet
     */
    public double[] forwardPassWithSoftmax(int oneHotInput){
        double[] result = forwardPass(oneHotInput);
        double[] softmax = new double[result.length];
        // Start from -infinity so the true maximum is found even when every score is negative;
        // subtracting it guarantees at least one exp(0) == 1 term in the denominator.
        double max = Double.NEGATIVE_INFINITY;
        for (double v : result) {
            if (max < v) {
                max = v;
            }
        }
        double sum = 0;
        for (int i = 0; i < result.length; i++) {
            softmax[i] = Math.exp(result[i] - max);
            sum += softmax[i];
        }
        for (int i = 0; i < result.length; i++) {
            softmax[i] /= sum;
        }
        return softmax;
    }
}
이미지

Skip gram

Word2Vec은 크게 다음과 같이 나뉜다.

  1. CBOW는 주변 단어로 중심 단어를 예측한다.
  2. Skip gram은 중심 단어를 통해 주변 단어를 예측한다.

Word2Vec은 위와 같은 과정을 통해 결과적으로 단어들을 의미상의 벡터로 표현할 수 있다. 이 프로젝트에서는 의미 표현 측면에서 skip gram이 더 효과가 좋다고 하여 skip gram을 이용하였다. 출처(p. 7)

@Override
    public void initData() {
        // Capture the vocabulary size first; the weight matrices are sized from it.
        size = morpheme.getWordIdxMap().size();
        List<String> nounList = morpheme.findAllNounByDestination();
        initArray();

        int epochNo = 0;
        while (epochNo < epoch) {
            learningWeight(nounList);

            // Print one sample weight from each matrix after the pass so that
            // diverging values are easy to spot in the console.
            double sampleOutputWeight = hiddenOutputWeight[0][0];
            double sampleInputWeight = inputHiddenWeight[0][0];
            System.out.println("w1 > " + sampleOutputWeight);
            System.out.println("w2 > " + sampleInputWeight);
            System.out.println("epoch > " + epochNo);

            epochNo++;
        }
    }

출력문은 발산 확인용으로 넣었다. morpheme은 komoran 형태소 분석 라이브러리를 사용해 데이터로부터 명사를 모은 사전을 만드는 클래스이다. initArray()는 [히든 노드 개수][사전 크기] 크기의 배열 2개를 작은 랜덤값으로 초기화한다. 그 후 epoch 횟수만큼 learningWeight()를 호출하여 연결강도를 학습시킨다.

/**
 * Runs one training pass: every position in {@code nounList} is used as the center word and
 * each neighbour within {@code windowSize} positions is used as a prediction target.
 *
 * @param nounList the noun sequence to train on, in corpus order
 */
private void learningWeight(List<String> nounList) {
        int total = nounList.size();
        // Visit ALL positions: the window is clamped at both ends, so the trailing words
        // can be center words too (previously the last windowSize words were skipped).
        for (int center = 0; center < total; center++) {
            int startWindow = Math.max(0, center - windowSize);
            int endWindow = Math.min(total, center + windowSize + 1);
            // The center word is simply the current position; no special-casing is needed
            // near the start because only the window bounds are clamped.
            int oneHotInput = morpheme.getIdx(nounList.get(center));
            for (int j = startWindow; j < endWindow; j++) {
                if (j == center) {
                    continue;
                }
                int oneHotOutput = morpheme.getIdx(nounList.get(j));
                // Recomputed per context word because learn() changes the weights in between.
                double[] result = forwardPassWithSoftmax(oneHotInput);
                learn(result, oneHotInput, oneHotOutput);
            }
        }
    }