Skip to content

Commit

Permalink
Merge pull request #99 from Team-KeepGoing/feature/face
Browse files Browse the repository at this point in the history
Feat :: 얼굴 인식 기능 개발
  • Loading branch information
miraexhoi authored Oct 19, 2024
2 parents 4c9f40a + e36d2bc commit 461d037
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 0 deletions.
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ dependencies {

implementation 'org.mapstruct:mapstruct:1.5.5.Final'
annotationProcessor 'org.mapstruct:mapstruct-processor:1.5.5.Final'

implementation 'software.amazon.awssdk:rekognition:2.20.21'
implementation 'software.amazon.awssdk:s3:2.20.21'
}

tasks.named('test') {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.keepgoing.keepserver.domain.face.controller;

import com.keepgoing.keepserver.domain.face.service.FaceRecognitionService;
import com.keepgoing.keepserver.global.common.BaseResponse;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@Tag(name = "얼굴인식", description = "얼굴인식 관련 api 입니다.")
@RestController
@RequestMapping("/face")
@RequiredArgsConstructor
public class FaceRecognitionController {

private final FaceRecognitionService faceRecognitionService;

@Operation(summary = "얼굴 인식", description = "얼굴을 인식합니다.")
@PostMapping(value = "/compare", consumes = { MediaType.MULTIPART_FORM_DATA_VALUE, MediaType.APPLICATION_JSON_VALUE })
public BaseResponse compareFaces(@RequestPart("image") MultipartFile image) {
return faceRecognitionService.compareFaces(image);
}

@Operation(summary = "얼굴 등록", description = "얼굴을 등록합니다.")
@PostMapping(value = "/register", consumes = { MediaType.MULTIPART_FORM_DATA_VALUE, MediaType.APPLICATION_JSON_VALUE })
public BaseResponse registerFace(@RequestPart("email") String email, @RequestPart("image") MultipartFile image) {
return faceRecognitionService.registerFace(email, image);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.keepgoing.keepserver.domain.face.entity;

import jakarta.persistence.*;
import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
@Entity
@Table(name = "face")
public class Face {

@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;

@Column(nullable = false, unique = true)
private String email;

@Column(name = "s3_image_url")
private String s3ImageUrl;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.keepgoing.keepserver.domain.face.repository;

import com.keepgoing.keepserver.domain.face.entity.Face;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.stereotype.Repository;

import java.util.Optional;

@Repository
public interface FaceRepository extends JpaRepository<Face, Long> {
Optional<Face> findByEmail(String email);

Optional<Face> findByS3ImageUrl(String s3ImageUrl);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.keepgoing.keepserver.domain.face.service;

import com.keepgoing.keepserver.global.common.BaseResponse;
import org.springframework.web.multipart.MultipartFile;

public interface FaceRecognitionService {
BaseResponse compareFaces(MultipartFile sourceImage);

BaseResponse registerFace(String email, MultipartFile image);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.keepgoing.keepserver.domain.face.service;

import com.keepgoing.keepserver.domain.face.entity.Face;
import com.keepgoing.keepserver.domain.face.repository.FaceRepository;
import com.keepgoing.keepserver.domain.image.service.ImageService;
import com.keepgoing.keepserver.global.common.BaseResponse;
import com.keepgoing.keepserver.global.common.S3.S3Uploader;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import software.amazon.awssdk.services.rekognition.RekognitionClient;
import software.amazon.awssdk.services.rekognition.model.CompareFacesRequest;
import software.amazon.awssdk.services.rekognition.model.CompareFacesResponse;

import java.util.List;
import java.util.Optional;

@Service
@RequiredArgsConstructor
public class FaceRecognitionServiceImpl implements FaceRecognitionService {

private final ImageService imageService;
private final RekognitionClient rekognitionClient;
private final FaceRepository faceRepository;
private final S3Uploader s3Uploader;

@Override
public BaseResponse compareFaces(MultipartFile sourceImage) {
String tempImageUrl = null;

try {
tempImageUrl = s3Uploader.upload(sourceImage, "face-android");

List<String> studentImageUrls = imageService.getAllImageUrlsFromS3("upload");

for (String studentImageUrl : studentImageUrls) {
CompareFacesRequest request = CompareFacesRequest.builder()
.sourceImage(imageService.getS3Image(tempImageUrl)) // 임시 저장된 이미지 사용
.targetImage(imageService.getS3Image(studentImageUrl)) // 학생 등록 이미지
.similarityThreshold(80F)
.build();

CompareFacesResponse response = rekognitionClient.compareFaces(request);

if (!response.faceMatches().isEmpty()) {
Optional<Face> matchedUser = faceRepository.findByS3ImageUrl(studentImageUrl);
if (matchedUser.isPresent()) {
return new BaseResponse(HttpStatus.OK, "얼굴 인식 성공", matchedUser.get().getEmail());
}
}
}

return new BaseResponse(HttpStatus.OK, "얼굴 등록이 되지 않은 이용자입니다.");
} catch (Exception e) {
return new BaseResponse(HttpStatus.INTERNAL_SERVER_ERROR, "비교 중 오류가 발생했습니다.", e.getMessage());
} finally {
if (tempImageUrl != null) {
s3Uploader.removeFaceAndroidFile(tempImageUrl);
}
}
}

@Override
public BaseResponse registerFace(String email, MultipartFile image) {
try {
Optional<Face> existingUser = faceRepository.findByEmail(email);
if (existingUser.isPresent()) {
return new BaseResponse(HttpStatus.CONFLICT, "이미 해당 이메일로 등록된 사용자가 있습니다.");
}

String s3ImageUrl = s3Uploader.upload(image,"upload");

Face newUser = new Face();
newUser.setEmail(email);
newUser.setS3ImageUrl(s3ImageUrl);
faceRepository.save(newUser);

return new BaseResponse(HttpStatus.OK, "얼굴 등록 성공", email);
} catch (Exception e) {
return new BaseResponse(HttpStatus.INTERNAL_SERVER_ERROR, "얼굴 등록 실패", e.getMessage());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.rekognition.model.Image;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;

@Service
@RequiredArgsConstructor
Expand All @@ -23,5 +27,16 @@ public ImageDTO uploadImage(MultipartFile multipartFile) {
throw new BusinessException(ErrorCode.FILE_ERROR);
}
}

public Image getS3Image(String s3ImageUrl) throws IOException {
byte[] imageBytes = s3Uploader.getObjectBytes(s3ImageUrl);
return Image.builder()
.bytes(SdkBytes.fromByteBuffer(ByteBuffer.wrap(imageBytes)))
.build();
}

public List<String> getAllImageUrlsFromS3(String directory) {
return s3Uploader.getAllImageUrlsFromS3(directory);
}
}

Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.keepgoing.keepserver.global.common.S3;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.ListObjectsV2Request;
import com.amazonaws.services.s3.model.ListObjectsV2Result;
import com.amazonaws.services.s3.model.PutObjectRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -12,6 +15,11 @@
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

@Slf4j
Expand Down Expand Up @@ -39,6 +47,7 @@ private String upload(File uploadFile, String dirName) {
return uploadImageUrl; // 업로드된 파일의 S3 URL 주소 반환
}


private String putS3(File uploadFile, String fileName) {
amazonS3Client.putObject(
new PutObjectRequest(bucket, fileName, uploadFile)
Expand All @@ -54,6 +63,48 @@ private void removeNewFile(File targetFile) {
}
}

public void removeFaceAndroidFile(String s3ImageUrl) {
String key = extractKeyFromUrl(s3ImageUrl);
if (key.startsWith("face-android/")) {
amazonS3Client.deleteObject(bucket, key);
log.info("파일이 삭제되었습니다 " + s3ImageUrl);
} else {
log.warn("버킷에 없는 파일입니다 " + s3ImageUrl);
}
}

private String extractKeyFromUrl(String s3ImageUrl) {
String decodedUrl = URLDecoder.decode(s3ImageUrl, StandardCharsets.UTF_8);
return decodedUrl.split(".com/")[1]; // uploads/ 이후의 경로만 추출
}

public List<String> getAllImageUrlsFromS3(String directory) {
List<String> imageUrls = new ArrayList<>();
ListObjectsV2Request request = new ListObjectsV2Request()
.withBucketName(bucket)
.withPrefix(directory);

ListObjectsV2Result result = amazonS3Client.listObjectsV2(request);

result.getObjectSummaries().forEach(s3Object -> {
String imageUrl = amazonS3Client.getUrl(bucket, s3Object.getKey()).toString();
imageUrls.add(imageUrl);
});

return imageUrls;
}

// S3에서 이미지 데이터를 바이트 배열로 가져오기
public byte[] getObjectBytes(String s3ImageUrl) throws IOException {
String key = extractKeyFromUrl(s3ImageUrl); // URL에서 키 추출
try (InputStream inputStream = amazonS3Client.getObject(new GetObjectRequest(bucket, key)).getObjectContent()) {
return inputStream.readAllBytes();
} catch (IOException e) {
log.error(e.getMessage());
throw new IOException("S3 객체 파일 변환 실패", e);
}
}

private File convert(MultipartFile multipartFile) throws IOException {
File file = new File(Objects.requireNonNull(multipartFile.getOriginalFilename()));
try (FileOutputStream fos = new FileOutputStream(file)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.keepgoing.keepserver.global.util;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.rekognition.RekognitionClient;
import software.amazon.awssdk.services.s3.S3Client;

@Configuration
public class AwsRekognitionClient {

@Value("${cloud.aws.credentials.accessKey}")
private String accessKeyId;

@Value("${cloud.aws.credentials.secretKey}")
private String secretKey;

@Value("${cloud.aws.region.static}")
private String region;

@Bean
public RekognitionClient rekognitionClient() {
AwsBasicCredentials awsCreds = AwsBasicCredentials.create(accessKeyId, secretKey);
return RekognitionClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCreds))
.region(Region.of(region))
.build();
}

@Bean
public S3Client s3Client() {
AwsBasicCredentials awsCreds = AwsBasicCredentials.create(accessKeyId, secretKey);
return S3Client.builder()
.credentialsProvider(StaticCredentialsProvider.create(awsCreds))
.region(Region.of(region))
.build();
}
}

0 comments on commit 461d037

Please sign in to comment.