Skip to content

Word2Vec (skip‐gram)

MoonDooo edited this page Jul 6, 2024 · 12 revisions
이미지

Skip gram

Word2Vec은 크게 다음과 같이 나뉘다.

  1. CBOW는 주변 단어로 중심 단어를 예측한다.
  2. Skip gram은 중심 단어를 통해 주변 단어를 예측한다.

word2Vec는 위와 같은 과정을 통해 결과적으로 단어들을 의미상의 벡터로 표현할 수 있다. 이 프로젝트는 의미 부분에서는 skip gram이 더 효과가 좋다고 하여 skip gram을 이용하였다. 출처(p. 7)

skip gram은 히든레이어의 활성화함수로 선형함수를 쓰며 출력 레이어에서 softmax를 사용하며 계산의 용이성과 오차가 잘 반영되는 cross-entropy를 사용한다.


WordVectorInitializer 전체 코드
@Component
@Slf4j
public class WordVectorInitializerImpl implements WordVectorInitializer {

    @Value("${nlp.dimension}")
    protected Integer dimension;
    @Value("${nlp.word2vec.learnRate}")
    private Double learnRate;
    @Value("${nlp.word2vec.window}")
    private Integer windowSize;
    @Value("${nlp.word2vec.epoch}")
    private Integer epoch;

    private final Morpheme morpheme;
    private final DestinationOverviewNounExtractor de;
    private final EmbeddingModel em;
    private final DestinationWordVector destinationWordVector;
    private final WordVectorFileReader wordVectorFileReader;
    private final WordVectorFileWriter wordVectorFileWriter;

    public WordVectorInitializerImpl(Morpheme morpheme, DestinationOverviewNounExtractor de, @Qualifier("skipGram") EmbeddingModel em, DestinationWordVector destinationWordVector, WordVectorFileReader wordVectorFileReader, WordVectorFileWriter wordVectorFileWriter) {
        this.morpheme = morpheme;
        this.de = de;
        this.em = em;
        this.destinationWordVector = destinationWordVector;
        this.wordVectorFileReader = wordVectorFileReader;
        this.wordVectorFileWriter = wordVectorFileWriter;
    }

    @Override
    @Time
    public void initDestinationWordVector(){
        WordVector wordVector = wordVectorFileReader.readWordWeight();
        if (valid(wordVector)){
            log.info("파일로부터 연결강도 설정");
            morpheme.init(wordVector.getMapping());
            destinationWordVector.initData(
                    wordVector.getInputHiddenWeight(),
                    wordVector.getHiddenOutputWeight()
            );
        }else{
            morpheme.init();
            createWordWeight();
            wordVectorFileWriter.saveFile();
        }
    }

    private boolean valid(WordVector wordVector) {
        return wordVector != null && morpheme.isValid(wordVector.getMapping()) && wordVector.getDimension() == dimension;
    }

    private void createWordWeight() {
        WeightBuilder weightBuilder = learningWeightByEmbeddingModel();
        destinationWordVector.initData(
                weightBuilder.getInputHiddenWeight(),
                weightBuilder.getHiddenOutputWeight()
        );
    }

    private WeightBuilder learningWeightByEmbeddingModel() {
        List<List<String>> nounListGroupByDestination = de.findAllNounGroupByDestination();
        return em.learningWeight(
                LearningBuilder.builder()
                        .learnRate(learnRate)
                        .epoch(epoch)
                        .dimension(dimension)
                        .window(windowSize)
                        .documentWordList(nounListGroupByDestination)
                        .build()
        );
    }
}

시작

    @Override
    @Time
    public void initDestinationWordVector(){
        WordVector wordVector = wordVectorFileReader.readWordWeight();
        if (valid(wordVector)){
            log.info("파일로부터 연결강도 설정");
            morpheme.init(wordVector.getMapping());
            destinationWordVector.initData(
                    wordVector.getInputHiddenWeight(),
                    wordVector.getHiddenOutputWeight()
            );
        }else{
            morpheme.init();
            createWordWeight();
            wordVectorFileWriter.saveFile();
        }
    }

    private boolean valid(WordVector wordVector) {
        return wordVector != null && morpheme.isValid(wordVector.getMapping()) && wordVector.getDimension() == dimension;
    }

word2vec의 연결강도는 다음과 같이 2가지 경우로 나뉜다.

  1. yml에 설정한 해당 파일을 불러와 연결강도를 설정한다. ( 이미 학습된 데이터 )
  2. 학습 때문에 프로덕션 환경과 테스트 환경을 자유롭게 변경 가능하다.

skip gram 전체 코드
@Component("skipGram")
@Slf4j
@RequiredArgsConstructor
public class SkipGram implements Word2Vec {
    private final Morpheme morpheme;
    private final Backpropagation backpropagation;

    @Override
    public WeightBuilder learningWeight(LearningBuilder builder) {
        validWeightBuilder(builder);
        for (int i = 0; i < builder.getEpoch(); i++) {
            learn(builder);
        }
        normalized(builder);
        return builder.getWeightBuilder();
    }

    private void normalized(LearningBuilder builder) {
        double[][] vectors = builder.getWeightBuilder().getInputHiddenWeight();
        for (double[] vec : vectors){
            NormalizedVector.normalizedVector(vec);
        }
    }

    @Override
    public double[] forwardPassWithSoftmax(WeightBuilder weightBuilder, int oneHotInput) {
        return backpropagation.forwardPassWithSoftmaxForOneHotEncoding(
                weightBuilder.getInputHiddenWeight(),
                weightBuilder.getHiddenOutputWeight(),
                oneHotInput,
                ActivationFunction.linear()
        );
    }

    private void learn(LearningBuilder builder) {
        for (int i = 0; i< builder.getDocumentWordList().size(); i++){
            List<String> wordByDocument = builder.getDocumentWordList().get(i);
            initIdxAndBackpropagation(builder, wordByDocument);
        }
    }

    private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
        for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
            int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
            int startWindow = Math.max(0, windowIdx- builder.getWindow());
            int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
            for (int k = startWindow; k< endWindow; k++){
                if (k == windowIdx) continue;
                int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
                double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
                learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
            }
        }
    }

    private int getOneHotIdx(String s) {
        return morpheme.getIdx(s);
    }

    private void learnByBackpropagation(LearningBuilder builder, int oneHotInput, int oneHotOutput, double[] result) {

        backpropagation.learnForOneHotEncodingWithSoftmax(
                builder.getWeightBuilder().getInputHiddenWeight(),
                builder.getWeightBuilder().getHiddenOutputWeight(),
                builder.getLearnRate(),
                result,
                oneHotInput,
                oneHotOutput,
                ActivationFunctionDifferential.linear()
        );

    }

    public void validWeightBuilder(LearningBuilder builder) {
        if (builder.getWeightBuilder()==null){
            initWeightBuilder(builder);
        }else{
            if (builder.getWeightBuilder().getHiddenOutputWeight()==null|| builder.getWeightBuilder().getInputHiddenWeight()==null){
                initWeightBuilder(builder);
            }
        }
    }

    private void initWeightBuilder(LearningBuilder builder) {
        WeightBuilder weightBuilder = initParameter(builder);
        builder.setWeightBuilder(weightBuilder);
    }
    private WeightBuilder initParameter(LearningBuilder builder) {
        double[][] inputHiddenWeight = InitArray.initArrayToRandom(morpheme.getWordIdxMap().size(), builder.getDimension());
        double[][] hiddenOutputWeight = InitArray.initArrayToRandom(builder.getDimension(), morpheme.getWordIdxMap().size());
        return WeightBuilder
                .builder()
                .inputHiddenWeight(inputHiddenWeight)
                .hiddenOutputWeight(hiddenOutputWeight)
                .build();
    }
    
}

learning weight

    @Override
    public WeightBuilder learningWeight(LearningBuilder builder) {
        validWeightBuilder(builder);
        for (int i = 0; i < builder.getEpoch(); i++) {
            learn(builder);
        }
        normalized(builder);
        return builder.getWeightBuilder();
    }

먼저 학습하고자하는 LearningBuilder이 학습 가능한 상태인지 확인한다. 확인된다면 에포크만큼 학습시킨다. 이후 정규화를 시켜 최종적으로 학습된 연결강도를 반환한다.

learn&initIdxAndBackpropagation

    private void learn(LearningBuilder builder) {
        for (int i = 0; i< builder.getDocumentWordList().size(); i++){
            List<String> wordByDocument = builder.getDocumentWordList().get(i);
            initIdxAndBackpropagation(builder, wordByDocument);
        }
    }
    private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
        for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
            int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
            int startWindow = Math.max(0, windowIdx- builder.getWindow());
            int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
            for (int k = startWindow; k< endWindow; k++){
                if (k == windowIdx) continue;
                int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
                double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
                learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
            }
        }
    }

LearnBuilder 의 DocumentWordList는 현재 학습시키고자 하는 문서의 순서화된 문자열들이 담겨있다. 이후 현재 중심이되는 windowIdx에 해당하는 입력노드 인덱스 oneHotInput 와 그 주변 단어에 해당하는 인덱스를 oneHotOutput 을 순방향 신경망을 통해 계산한다. 이때 생겨난 result는 다시 역전파를 통해 Learningbuilder를 학습시킨다.

순전파

    @Override
    public double[] forwardPassWithSoftmax(WeightBuilder weightBuilder, int oneHotInput) {
        return backpropagation.forwardPassWithSoftmaxForOneHotEncoding(
                weightBuilder.getInputHiddenWeight(),
                weightBuilder.getHiddenOutputWeight(),
                oneHotInput,
                ActivationFunction.linear()
        );
    }

    //Backpropagation - ShallowNeuralNetwork 클래스
    @Override
    public double[] forwardPassWithSoftmaxForOneHotEncoding(
            double[][] inputHiddenWeight,
            double[][] hiddenOutputWeight,
            int oneHotInput,
            Function<Double, Double> activationFunc) {
        double[] result = forwardPassForOneHotEncoding(inputHiddenWeight, hiddenOutputWeight, oneHotInput, activationFunc);
        return ActivationFunction.softmaxArray().apply(result);
    }
    private static double[] forwardPassForOneHotEncoding(
            double[][] inputHiddenWeight,
            double[][] hiddenOutputWeight,
            int oneHotInput,
            Function<Double, Double> activationFunc) {
        int size = hiddenOutputWeight[0].length;
        int dimension = inputHiddenWeight[0].length;
        double[] result = new double[size];
        for (int i = 0; i<dimension; i++){
            for (int j = 0; j<size; j++){
                result[j] += activationFunc.apply(inputHiddenWeight[oneHotInput][i])*hiddenOutputWeight[i][j];
            }
        }
        return result;
    }

재사용성을 높이기 위해서 히든노드의 활성화함수를 함수형 인터페이스로 구현하였다. 하지만 skip-gram인 경우 벡터의 표현식을 늘리기 위해서 기본적으로 선형 함수를 사용하며 히든노드가 1계층이므로 사실상 입력노드와 히든노드의 행렬 곱이다. 이후 ActivationFunctionsoftmax를 출력 노드의 활성화함수로 사용하여 최종적으로 result를 만들어낸다. 이때 ActivationFunction의 구현은 다음과 같다.

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ActivationFunction {
    public static Function<double[], double[]> softmaxArray(){
        return input->{
            double[] softmax = new double[input.length];
            double max = getMax(input);
            double exp = getExp(input, max);
            for (int i = 0; i< input.length; i++){
                softmax[i] = Math.exp(input[i]-max)/exp;
            }
            return softmax;
        };
    }

    public static Function<Double, Double> linear(){
        return input->input;
    }

    private static double getExp(double[] result, double max) {
        double tmp = 0;
        for (double v : result) {
            tmp += Math.exp(v- max);
        }
        return tmp;
    }

    private static double getMax(double[] result) {
        return Arrays.stream(result).max().orElse(Double.NEGATIVE_INFINITY);
    }
}

softmax의 자연상수 제곱이 double의 범위를 넘어서는 경우가 있으므로 최대값을 빼서 계산하였다.

학습

    private void learn(LearningBuilder builder) {
        for (int i = 0; i< builder.getDocumentWordList().size(); i++){
            List<String> wordByDocument = builder.getDocumentWordList().get(i);
            initIdxAndBackpropagation(builder, wordByDocument);
        }
    }

    private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
        for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
            int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
            int startWindow = Math.max(0, windowIdx- builder.getWindow());
            int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
            for (int k = startWindow; k< endWindow; k++){
                if (k == windowIdx) continue;
                int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
                double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
                learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
            }
        }
    }

    private void learnByBackpropagation(LearningBuilder builder, int oneHotInput, int oneHotOutput, double[] result) {
        backpropagation.learnForOneHotEncodingWithSoftmax(
                builder.getWeightBuilder().getInputHiddenWeight(),
                builder.getWeightBuilder().getHiddenOutputWeight(),
                builder.getLearnRate(),
                result,
                oneHotInput,
                oneHotOutput,
                ActivationFunctionDifferential.linear()
        );
    }
   //Backpropagation - ShallowNeuralNetwork 클래스
    @Override
    public void learnForOneHotEncodingWithSoftmax(
            double[][] inputHiddenWeight,
            double[][] hiddenOutputWeight,
            Double learnRate,
            double[] result,
            int oneHotInput,
            int oneHotOutput,
            Function<Double, Double> activationFunctionDifferential) {
        double[] delta = new double[result.length];
        for (int i =0; i<result.length; i++){
            if (i == oneHotOutput){
                delta[i] = result[i]-1;
            }else{
                delta[i] = result[i];
            }
        }
        for (int i =0; i<inputHiddenWeight[0].length; i++){
            double hiddenDelta = 0;
            for (int j =0; j<result.length; j++){
                hiddenOutputWeight[i][j]-=learnRate*inputHiddenWeight[oneHotInput][i]*delta[j];
                hiddenDelta += hiddenOutputWeight[i][j]*delta[j];
            }
            hiddenDelta *= activationFunctionDifferential.apply(inputHiddenWeight[oneHotInput][i]);
            inputHiddenWeight[oneHotInput][i] -= learnRate*hiddenDelta;
        }
    }
 

학습 방칙으로는 오류 역전파를 사용하므로 cross etropy와 softmax를 편미분하면 (출력값)-(타겟값)이다. 이때 타겟값은 현재 타겟인 출력노드를 제외하고는 0이므로 출력노드일때만 1을 빼준다. 연결강도의 종류를 다음과 같이 분류하여 학습할 수 있다.

  • [히든-출력 연결강도] : -(학습률)(히든노드의 출력)*(델타)이고 (히든노드의 출력) = (1인 입력노드간의 연결강도)이므로 -learnRate*inputHiddenWeight[i][oneHotInput]*delta[j]
  • [입력-히든 연결강도] : 일반적인 델타 법칙으로 인해 선형함수이므로 활성화함수 미분은 생략 (히드노드의 델타)=(출력노드의 델타)*(그 출력노드와 히든노드의 연결강도)의 합이다. 입력노드의 출력은 1로 고정이므로 다음과 같다. inputHiddenWeight[i][oneHotInput] -= learnRate*hiddenDelta
    private void normalized(LearningBuilder builder) {
        double[][] vectors = builder.getWeightBuilder().getInputHiddenWeight();
        for (double[] vec : vectors){
            NormalizedVector.normalizedVector(vec);
        }
    }
    
    public class NormalizedVector {
    public static void normalizedVector(double[] vector){
        double norm = 0;
        for (double v: vector){
            norm += v*v;
        }
        norm = Math.sqrt(norm);
        if (norm == 0) {
            throw new IllegalArgumentException("벡터중에 0이 있으면 안됩니다.");
        }
        for (int i =0; i<vector.length; i++){
            vector[i] = vector[i]/norm;
        }
    }
    public static void normalizedVector(double[][] vectors){
        for (double[] vec : vectors){
            NormalizedVector.normalizedVector(vec);
        }
    }
}

최종적으로 L2 정규화를 통해 벡터의 크기를 1로 만들어 이후의 코사인 유사도 계산에 더 유용하게 만든다.