-
Notifications
You must be signed in to change notification settings - Fork 1
Word2Vec (skip‐gram)
이미지
Word2Vec은 크게 다음과 같이 나뉘다.
- CBOW는 주변 단어로 중심 단어를 예측한다.
- Skip gram은 중심 단어를 통해 주변 단어를 예측한다.
word2Vec는 위와 같은 과정을 통해 결과적으로 단어들을 의미상의 벡터로 표현할 수 있다. 이 프로젝트는 의미 부분에서는 skip gram이 더 효과가 좋다고 하여 skip gram을 이용하였다. 출처(p. 7)
skip gram은 히든레이어의 활성화함수로 선형함수를 쓰며 출력 레이어에서 softmax를 사용하며 계산의 용이성과 오차가 잘 반영되는 cross-entropy를 사용한다.
WordVectorInitializer 전체 코드
@Component
@Slf4j
public class WordVectorInitializerImpl implements WordVectorInitializer {
@Value("${nlp.dimension}")
protected Integer dimension;
@Value("${nlp.word2vec.learnRate}")
private Double learnRate;
@Value("${nlp.word2vec.window}")
private Integer windowSize;
@Value("${nlp.word2vec.epoch}")
private Integer epoch;
private final Morpheme morpheme;
private final DestinationOverviewNounExtractor de;
private final EmbeddingModel em;
private final DestinationWordVector destinationWordVector;
private final WordVectorFileReader wordVectorFileReader;
private final WordVectorFileWriter wordVectorFileWriter;
public WordVectorInitializerImpl(Morpheme morpheme, DestinationOverviewNounExtractor de, @Qualifier("skipGram") EmbeddingModel em, DestinationWordVector destinationWordVector, WordVectorFileReader wordVectorFileReader, WordVectorFileWriter wordVectorFileWriter) {
this.morpheme = morpheme;
this.de = de;
this.em = em;
this.destinationWordVector = destinationWordVector;
this.wordVectorFileReader = wordVectorFileReader;
this.wordVectorFileWriter = wordVectorFileWriter;
}
@Override
@Time
public void initDestinationWordVector(){
WordVector wordVector = wordVectorFileReader.readWordWeight();
if (valid(wordVector)){
log.info("파일로부터 연결강도 설정");
morpheme.init(wordVector.getMapping());
destinationWordVector.initData(
wordVector.getInputHiddenWeight(),
wordVector.getHiddenOutputWeight()
);
}else{
morpheme.init();
createWordWeight();
wordVectorFileWriter.saveFile();
}
}
private boolean valid(WordVector wordVector) {
return wordVector != null && morpheme.isValid(wordVector.getMapping()) && wordVector.getDimension() == dimension;
}
private void createWordWeight() {
WeightBuilder weightBuilder = learningWeightByEmbeddingModel();
destinationWordVector.initData(
weightBuilder.getInputHiddenWeight(),
weightBuilder.getHiddenOutputWeight()
);
}
private WeightBuilder learningWeightByEmbeddingModel() {
List<List<String>> nounListGroupByDestination = de.findAllNounGroupByDestination();
return em.learningWeight(
LearningBuilder.builder()
.learnRate(learnRate)
.epoch(epoch)
.dimension(dimension)
.window(windowSize)
.documentWordList(nounListGroupByDestination)
.build()
);
}
}
@Override
@Time
public void initDestinationWordVector(){
WordVector wordVector = wordVectorFileReader.readWordWeight();
if (valid(wordVector)){
log.info("파일로부터 연결강도 설정");
morpheme.init(wordVector.getMapping());
destinationWordVector.initData(
wordVector.getInputHiddenWeight(),
wordVector.getHiddenOutputWeight()
);
}else{
morpheme.init();
createWordWeight();
wordVectorFileWriter.saveFile();
}
}
private boolean valid(WordVector wordVector) {
return wordVector != null && morpheme.isValid(wordVector.getMapping()) && wordVector.getDimension() == dimension;
}
word2vec의 연결강도는 다음과 같이 2가지 경우로 나뉜다.
- yml에 설정한 해당 파일을 불러와 연결강도를 설정한다. ( 이미 학습된 데이터 )
- 학습 때문에 프로덕션 환경과 테스트 환경을 자유롭게 변경 가능하다.
skip gram 전체 코드
@Component("skipGram")
@Slf4j
@RequiredArgsConstructor
public class SkipGram implements Word2Vec {
private final Morpheme morpheme;
private final Backpropagation backpropagation;
@Override
public WeightBuilder learningWeight(LearningBuilder builder) {
validWeightBuilder(builder);
for (int i = 0; i < builder.getEpoch(); i++) {
learn(builder);
}
normalized(builder);
return builder.getWeightBuilder();
}
private void normalized(LearningBuilder builder) {
double[][] vectors = builder.getWeightBuilder().getInputHiddenWeight();
for (double[] vec : vectors){
NormalizedVector.normalizedVector(vec);
}
}
@Override
public double[] forwardPassWithSoftmax(WeightBuilder weightBuilder, int oneHotInput) {
return backpropagation.forwardPassWithSoftmaxForOneHotEncoding(
weightBuilder.getInputHiddenWeight(),
weightBuilder.getHiddenOutputWeight(),
oneHotInput,
ActivationFunction.linear()
);
}
private void learn(LearningBuilder builder) {
for (int i = 0; i< builder.getDocumentWordList().size(); i++){
List<String> wordByDocument = builder.getDocumentWordList().get(i);
initIdxAndBackpropagation(builder, wordByDocument);
}
}
private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
int startWindow = Math.max(0, windowIdx- builder.getWindow());
int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
for (int k = startWindow; k< endWindow; k++){
if (k == windowIdx) continue;
int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
}
}
}
private int getOneHotIdx(String s) {
return morpheme.getIdx(s);
}
private void learnByBackpropagation(LearningBuilder builder, int oneHotInput, int oneHotOutput, double[] result) {
backpropagation.learnForOneHotEncodingWithSoftmax(
builder.getWeightBuilder().getInputHiddenWeight(),
builder.getWeightBuilder().getHiddenOutputWeight(),
builder.getLearnRate(),
result,
oneHotInput,
oneHotOutput,
ActivationFunctionDifferential.linear()
);
}
public void validWeightBuilder(LearningBuilder builder) {
if (builder.getWeightBuilder()==null){
initWeightBuilder(builder);
}else{
if (builder.getWeightBuilder().getHiddenOutputWeight()==null|| builder.getWeightBuilder().getInputHiddenWeight()==null){
initWeightBuilder(builder);
}
}
}
private void initWeightBuilder(LearningBuilder builder) {
WeightBuilder weightBuilder = initParameter(builder);
builder.setWeightBuilder(weightBuilder);
}
private WeightBuilder initParameter(LearningBuilder builder) {
double[][] inputHiddenWeight = InitArray.initArrayToRandom(morpheme.getWordIdxMap().size(), builder.getDimension());
double[][] hiddenOutputWeight = InitArray.initArrayToRandom(builder.getDimension(), morpheme.getWordIdxMap().size());
return WeightBuilder
.builder()
.inputHiddenWeight(inputHiddenWeight)
.hiddenOutputWeight(hiddenOutputWeight)
.build();
}
}
@Override
public WeightBuilder learningWeight(LearningBuilder builder) {
validWeightBuilder(builder);
for (int i = 0; i < builder.getEpoch(); i++) {
learn(builder);
}
normalized(builder);
return builder.getWeightBuilder();
}
먼저 학습하고자하는 LearningBuilder
이 학습 가능한 상태인지 확인한다. 확인된다면 에포크만큼 학습시킨다. 이후 정규화를 시켜 최종적으로 학습된 연결강도를 반환한다.
private void learn(LearningBuilder builder) {
for (int i = 0; i< builder.getDocumentWordList().size(); i++){
List<String> wordByDocument = builder.getDocumentWordList().get(i);
initIdxAndBackpropagation(builder, wordByDocument);
}
}
private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
int startWindow = Math.max(0, windowIdx- builder.getWindow());
int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
for (int k = startWindow; k< endWindow; k++){
if (k == windowIdx) continue;
int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
}
}
}
LearnBuilder
의 DocumentWordList는 현재 학습시키고자 하는 문서의 순서화된 문자열들이 담겨있다. 이후 현재 중심이되는 windowIdx
에 해당하는 입력노드 인덱스 oneHotInput
와 그 주변 단어에 해당하는 인덱스를 oneHotOutput
을 순방향 신경망을 통해 계산한다. 이때 생겨난 result는 다시 역전파를 통해 Learningbuilder
를 학습시킨다.
@Override
public double[] forwardPassWithSoftmax(WeightBuilder weightBuilder, int oneHotInput) {
return backpropagation.forwardPassWithSoftmaxForOneHotEncoding(
weightBuilder.getInputHiddenWeight(),
weightBuilder.getHiddenOutputWeight(),
oneHotInput,
ActivationFunction.linear()
);
}
//Backpropagation - ShallowNeuralNetwork 클래스
@Override
public double[] forwardPassWithSoftmaxForOneHotEncoding(
double[][] inputHiddenWeight,
double[][] hiddenOutputWeight,
int oneHotInput,
Function<Double, Double> activationFunc) {
double[] result = forwardPassForOneHotEncoding(inputHiddenWeight, hiddenOutputWeight, oneHotInput, activationFunc);
return ActivationFunction.softmaxArray().apply(result);
}
private static double[] forwardPassForOneHotEncoding(
double[][] inputHiddenWeight,
double[][] hiddenOutputWeight,
int oneHotInput,
Function<Double, Double> activationFunc) {
int size = hiddenOutputWeight[0].length;
int dimension = inputHiddenWeight[0].length;
double[] result = new double[size];
for (int i = 0; i<dimension; i++){
for (int j = 0; j<size; j++){
result[j] += activationFunc.apply(inputHiddenWeight[oneHotInput][i])*hiddenOutputWeight[i][j];
}
}
return result;
}
재사용성을 높이기 위해서 히든노드의 활성화함수를 함수형 인터페이스로 구현하였다. 하지만 skip-gram인 경우 벡터의 표현식을 늘리기 위해서 기본적으로 선형 함수를 사용하며 히든노드가 1계층이므로 사실상 입력노드와 히든노드의 행렬 곱이다. 이후 ActivationFunction
의 softmax
를 출력 노드의 활성화함수로 사용하여 최종적으로 result를 만들어낸다. 이때 ActivationFunction
의 구현은 다음과 같다.
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ActivationFunction {
public static Function<double[], double[]> softmaxArray(){
return input->{
double[] softmax = new double[input.length];
double max = getMax(input);
double exp = getExp(input, max);
for (int i = 0; i< input.length; i++){
softmax[i] = Math.exp(input[i]-max)/exp;
}
return softmax;
};
}
public static Function<Double, Double> linear(){
return input->input;
}
private static double getExp(double[] result, double max) {
double tmp = 0;
for (double v : result) {
tmp += Math.exp(v- max);
}
return tmp;
}
private static double getMax(double[] result) {
return Arrays.stream(result).max().orElse(Double.NEGATIVE_INFINITY);
}
}
softmax의 자연상수 제곱이 double의 범위를 넘어서는 경우가 있으므로 최대값을 빼서 계산하였다.
private void learn(LearningBuilder builder) {
for (int i = 0; i< builder.getDocumentWordList().size(); i++){
List<String> wordByDocument = builder.getDocumentWordList().get(i);
initIdxAndBackpropagation(builder, wordByDocument);
}
}
private void initIdxAndBackpropagation(LearningBuilder builder, List<String> wordByDocument) {
for (int windowIdx = 0; windowIdx< wordByDocument.size(); windowIdx++){
int oneHotInput = getOneHotIdx(wordByDocument.get(windowIdx));
int startWindow = Math.max(0, windowIdx- builder.getWindow());
int endWindow = Math.min(wordByDocument.size(), windowIdx+ builder.getWindow()+1);
for (int k = startWindow; k< endWindow; k++){
if (k == windowIdx) continue;
int oneHotOutput = getOneHotIdx(wordByDocument.get(k));
double[] result = forwardPassWithSoftmax(builder.getWeightBuilder(), oneHotInput);
learnByBackpropagation(builder, oneHotInput, oneHotOutput, result);
}
}
}
private void learnByBackpropagation(LearningBuilder builder, int oneHotInput, int oneHotOutput, double[] result) {
backpropagation.learnForOneHotEncodingWithSoftmax(
builder.getWeightBuilder().getInputHiddenWeight(),
builder.getWeightBuilder().getHiddenOutputWeight(),
builder.getLearnRate(),
result,
oneHotInput,
oneHotOutput,
ActivationFunctionDifferential.linear()
);
}
//Backpropagation - ShallowNeuralNetwork 클래스
@Override
public void learnForOneHotEncodingWithSoftmax(
double[][] inputHiddenWeight,
double[][] hiddenOutputWeight,
Double learnRate,
double[] result,
int oneHotInput,
int oneHotOutput,
Function<Double, Double> activationFunctionDifferential) {
double[] delta = new double[result.length];
for (int i =0; i<result.length; i++){
if (i == oneHotOutput){
delta[i] = result[i]-1;
}else{
delta[i] = result[i];
}
}
for (int i =0; i<inputHiddenWeight[0].length; i++){
double hiddenDelta = 0;
for (int j =0; j<result.length; j++){
hiddenOutputWeight[i][j]-=learnRate*inputHiddenWeight[oneHotInput][i]*delta[j];
hiddenDelta += hiddenOutputWeight[i][j]*delta[j];
}
hiddenDelta *= activationFunctionDifferential.apply(inputHiddenWeight[oneHotInput][i]);
inputHiddenWeight[oneHotInput][i] -= learnRate*hiddenDelta;
}
}
학습 방칙으로는 오류 역전파를 사용하므로 cross etropy와 softmax를 편미분하면 (출력값)-(타겟값)이다. 이때 타겟값은 현재 타겟인 출력노드를 제외하고는 0이므로 출력노드일때만 1을 빼준다. 연결강도의 종류를 다음과 같이 분류하여 학습할 수 있다.
- [히든-출력 연결강도] : -(학습률)(히든노드의 출력)*(델타)이고 (히든노드의 출력) = (1인 입력노드간의 연결강도)이므로
-learnRate*inputHiddenWeight[i][oneHotInput]*delta[j]
- [입력-히든 연결강도] : 일반적인 델타 법칙으로 인해
선형함수이므로 활성화함수 미분은 생략(히드노드의 델타)=(출력노드의 델타)*(그 출력노드와 히든노드의 연결강도)의 합이다. 입력노드의 출력은 1로 고정이므로 다음과 같다.inputHiddenWeight[i][oneHotInput] -= learnRate*hiddenDelta
private void normalized(LearningBuilder builder) {
double[][] vectors = builder.getWeightBuilder().getInputHiddenWeight();
for (double[] vec : vectors){
NormalizedVector.normalizedVector(vec);
}
}
public class NormalizedVector {
public static void normalizedVector(double[] vector){
double norm = 0;
for (double v: vector){
norm += v*v;
}
norm = Math.sqrt(norm);
if (norm == 0) {
throw new IllegalArgumentException("벡터중에 0이 있으면 안됩니다.");
}
for (int i =0; i<vector.length; i++){
vector[i] = vector[i]/norm;
}
}
public static void normalizedVector(double[][] vectors){
for (double[] vec : vectors){
NormalizedVector.normalizedVector(vec);
}
}
}
최종적으로 L2 정규화를 통해 벡터의 크기를 1로 만들어 이후의 코사인 유사도 계산에 더 유용하게 만든다.
Spring 임경완 |
---|
@ MoonDooo |