Skip to content

Commit

Permalink
feat: 기본 셋팅
Browse files Browse the repository at this point in the history
  • Loading branch information
taewan2002 committed Feb 15, 2024
1 parent ec24311 commit a2a1642
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 161 deletions.
3 changes: 2 additions & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
SERVER_TYPE=local
ROOT_PATH=
DB_URL=localhost
DB_URL=localhost
HOUSE_REC_URL=https://sarabwayu5.hackathon.sparcs.net/
3 changes: 2 additions & 1 deletion .env-prod
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
SERVER_TYPE=prod
ROOT_PATH=/api
DB_URL=mysql-container
DB_URL=mysql-container
HOUSE_REC_URL=https://sarabwayu5.hackathon.sparcs.net/
91 changes: 0 additions & 91 deletions app/core/ai.py

This file was deleted.

1 change: 1 addition & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class Settings(BaseSettings):
SERVER_TYPE: str
ROOT_PATH: str
DB_URL: str
HOUSE_REC_URL: str

class Config:
env_file = ".env"
Expand Down
10 changes: 9 additions & 1 deletion app/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@ async def get_current_user(
headers={"WWW-Authenticate": "Bearer"},
)

payload = jwt.decode(token, "sarabwayu", algorithms=["HS256"])
try:
payload = jwt.decode(token, "sarabwayu", algorithms=["HS256"])
except:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unauthorized",
headers={"WWW-Authenticate": "Bearer"},
)

nickname: str = payload.get("sub")
user = db.query(User).filter(User.nickname == nickname).first()

Expand Down
16 changes: 15 additions & 1 deletion app/db/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from sqlalchemy import Column, Integer, Text, ForeignKey, String, Boolean, DateTime, func, JSON, Date, FLOAT
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import pytz

def get_now():
return datetime.now(pytz.timezone('Asia/Seoul'))

Base = declarative_base()

class User(Base):
__tablename__ = 'User'

id = Column(Integer, primary_key=True)
nickname = Column(String(50), index=True, nullable=False)
hashed_password = Column(String(100), nullable=False)
Expand Down Expand Up @@ -40,14 +46,22 @@ class House(Base):

class Recommendation(Base):
__tablename__ = 'Recommendation'

id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('User.id'))
house_id = Column(Integer, ForeignKey('House.id'))
reason = Column(Text, nullable=False)
is_deleted = Column(Boolean, default=False)
create_date = Column(DateTime, default=func.now())
create_date = Column(DateTime, default=get_now())

class LikedHouse(Base):
__tablename__ = 'LikedHouse'

id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('User.id'))
house_id = Column(Integer, ForeignKey('House.id'))
is_deleted = Column(Boolean, default=False)
create_date = Column(DateTime, default=get_now())

def get_Base():
return Base
2 changes: 1 addition & 1 deletion app/router/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

router = APIRouter(prefix="/chat")

@router.post("/chat", response_model=ApiResponse, tags=["Chat"])
@router.post("/", response_model=ApiResponse, tags=["Chat"])
async def post_chat(
chat_data: Chat,
chat_service: Annotated[ChatService, Depends()]
Expand Down
16 changes: 12 additions & 4 deletions app/router/house.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,23 @@ async def post_house_create(
):
print(house_data.house_info)
return ApiResponse()
@router.patch("/like/{house_id}", response_model=ApiResponse, tags=["House"])
async def patch_house_like(
house_id: int,
house_service: Annotated[HouseService, Depends()]
):
return ApiResponse(data=await house_service.like(house_id))

@router.get("/recommendation", response_model=ApiResponse, tags=["House"])
@router.get("/recommendation/list/{page}", response_model=ApiResponse, tags=["House"])
async def get_house_recommendation(
page: int,
house_service: Annotated[HouseService, Depends()]
):
return ApiResponse(data=await house_service.recommendation())
return ApiResponse(data=await house_service.recommendation_list(page))

@router.get("/list", response_model=ApiResponse, tags=["House"])
@router.get("/list/{page}", response_model=ApiResponse, tags=["House"])
async def get_house_list(
page: int,
house_service: Annotated[HouseService, Depends()]
):
return ApiResponse(data=await house_service.list())
return ApiResponse(data=await house_service.list(page))
130 changes: 119 additions & 11 deletions app/service/chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from fastapi import Depends
import json

import requests
from fastapi import Depends, HTTPException, status
from sqlalchemy.orm import Session

from app.db.database import get_db, get_current_user
from app.db.models import User
from app.core.config import settings
from app.db.database import get_db, get_current_user, save_db
from app.db.models import User, House, Recommendation
from app.schemas.request import Chat
from app.service.house import HouseRecommender


class ChatService:
Expand All @@ -12,11 +17,114 @@ def __init__(self, db: Session = Depends(get_db), user: User = Depends(get_curre
self.user = user

async def chat(self, chat_data: Chat):
# print(chat_data.person_count)
# print(chat_data.period)
# print(chat_data.identity)
# print(chat_data.car)
# print(chat_data.child)
# print(chat_data.significant)

return chat_data

async def check_format(data):
if data.person_count not in ["1명", "2명", "3명", "4명 이상"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="person_count가 잘못되었습니다."
)
if data.period not in ["1주", "2주", "3주", "4주 이상"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="period가 잘못되었습니다."
)
if data.identity not in ["학생", "직장인", "기타"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="identity가 잘못되었습니다."
)
if data.car not in ["자차", "대중교통"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="car가 잘못되었습니다."
)
if data.child not in ["아이 있음", "아이 없음"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="child가 잘못되었습니다."
)
return data

chat_data = await check_format(chat_data)

persona = {
"person_count": chat_data.person_count,
"period": chat_data.period,
"identity": chat_data.identity,
"car": chat_data.car,
"child": chat_data.child,
"significant": chat_data.significant
}

# 모든 집 데이터 가져오기
all_houses = self.db.query(House).filter(House.is_deleted == False).all()

# 이미 추천된 데이터 제거
recommended_houses = self.db.query(Recommendation).filter(
Recommendation.user_id == self.user.id,
Recommendation.is_deleted == False
).all()
for house in all_houses:
if house.id in [recommended_house.house_id for recommended_house in recommended_houses]:
all_houses.remove(house)

# 추천 알고리즘 실행
house_recommender = HouseRecommender([house.__dict__ for house in all_houses])
recommended_houses = house_recommender.recommend(persona)

# 추천된 데이터 이름 - id 매핑
recommended_map = {}
for house in recommended_houses:
recommended_map[house[1]["aptName"]] = house[1]["id"]

# XAI를 활용한 추천 API 호출
candidates = []
for house in recommended_houses:
house_dict = {}
house = house[1]
house_dict['aptName'] = house['aptName']
house_dict['articleFeatureDescription'] = (house['articleFeatureDescription'] + ' ' + house[
'detailDescription'])[:100]
house_dict['tagList'] = house['tagList']
house_dict['walkTime'] = house['walkTime']
house_dict['studentCountPerTeacher'] = house['studentCountPerTeacher']
house_dict['aptParkingCountPerHousehold'] = house['aptParkingCountPerHousehold']
candidates.append(house_dict)

request_data = {
"user_info": json.dumps(persona, ensure_ascii=False),
"candidates": json.dumps(candidates, ensure_ascii=False)
}

retry_count = 3

while retry_count > 0:
try:
response = requests.post(settings.HOUSE_REC_URL, json=request_data)
rank_section = response.text.split("rank:")[1]
reason_section = rank_section.split("reason:")[1]
rank_data = rank_section.split("reason:")[0]
rank_data = rank_data[rank_data.find("["):rank_data.find("]") + 1]
reason_section = reason_section[reason_section.find("["):reason_section.find("]") + 1]
rank_data = json.loads(rank_data.replace('\\"', '"'))
return_data = json.loads(reason_section.replace('\\"', '"'))
break
except:
retry_count -= 1
print(f"API 호출 시도 중... {retry_count}회 남음")
if retry_count == 0:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"{rank_data}, {reason_section}"
)

for rank in rank_data:
recommendation = Recommendation(
user_id=self.user.id,
house_id=recommended_map[rank],
reason=return_data[rank_data.index(rank)]
)
save_db(recommendation, self.db)

return return_data
Loading

0 comments on commit a2a1642

Please sign in to comment.