diff --git a/backend/core/migrations/0007_forumanswerdownvote_forumanswerupvote.py b/backend/core/migrations/0007_forumanswerdownvote_forumanswerupvote.py new file mode 100644 index 00000000..a3f93af4 --- /dev/null +++ b/backend/core/migrations/0007_forumanswerdownvote_forumanswerupvote.py @@ -0,0 +1,41 @@ +# Generated by Django 5.1.2 on 2024-11-19 11:34 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0006_alter_forumdownvote_forum_question_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='ForumAnswerDownvote', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('forum_answer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='downvotes', to='core.forumanswer')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['-created_at'], + 'unique_together': {('user', 'forum_answer')}, + }, + ), + migrations.CreateModel( + name='ForumAnswerUpvote', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('forum_answer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='upvotes', to='core.forumanswer')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + 'ordering': ['-created_at'], + 'unique_together': {('user', 'forum_answer')}, + }, + ), + ] diff --git a/backend/core/models.py b/backend/core/models.py index 4d045bcc..12342013 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -160,4 +160,36 @@ class ForumAnswer(models.Model): created_at = models.DateTimeField(auto_now_add=True) def __str__(self): - return self.answer \ No newline at end of file + return self.answer + +class ForumAnswerUpvote(models.Model): + user = models.ForeignKey(CustomUser, on_delete=models.CASCADE) + forum_answer = models.ForeignKey(ForumAnswer, on_delete=models.CASCADE, related_name='upvotes') + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + unique_together = ("user", "forum_answer") + ordering = ["-created_at"] + + def save(self, *args, **kwargs): + ForumAnswerDownvote.objects.filter(user=self.user, forum_answer=self.forum_answer).delete() + super().save(*args, **kwargs) + + def __str__(self): + return f"{self.user} upvoted {self.forum_answer}" + +class ForumAnswerDownvote(models.Model): + user = models.ForeignKey(CustomUser, on_delete=models.CASCADE) + forum_answer = models.ForeignKey(ForumAnswer, on_delete=models.CASCADE, related_name='downvotes') + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + unique_together = ("user", "forum_answer") + ordering = ["-created_at"] + + def save(self, *args, **kwargs): + ForumAnswerUpvote.objects.filter(user=self.user, forum_answer=self.forum_answer).delete() + super().save(*args, **kwargs) + + def __str__(self): + return f"{self.user} downvoted {self.forum_answer}" \ No newline at end of file diff --git a/backend/core/serializers/forum_vote_serializer.py b/backend/core/serializers/forum_vote_serializer.py index 672487d3..47e8d286 100644 --- a/backend/core/serializers/forum_vote_serializer.py +++ b/backend/core/serializers/forum_vote_serializer.py @@ -2,7 +2,7 @@ from faker import Faker from rest_framework import serializers -from ..models import (ForumUpvote, ForumDownvote) +from ..models import (ForumUpvote, ForumDownvote, ForumAnswerUpvote, ForumAnswerDownvote) User = get_user_model() queryset = User.objects.all() @@ -38,3 +38,33 @@ def validate(self, attrs): raise serializers.ValidationError("You have already downvoted this forum question.") return attrs + +class ForumAnswerUpvoteSerializer(serializers.ModelSerializer): + class Meta: + model = ForumAnswerUpvote + fields = ("id", "user", "forum_answer", "created_at") + read_only_fields = ("id", "user", "created_at") + + def validate(self, attrs): + user = self.context["request"].user + forum_answer = attrs["forum_answer"] + + if ForumAnswerUpvote.objects.filter(user=user, forum_answer=forum_answer).exists(): + raise serializers.ValidationError("You have already upvoted this forum answer.") + + return attrs + +class ForumAnswerDownvoteSerializer(serializers.ModelSerializer): + class Meta: + model = ForumAnswerDownvote + fields = ("id", "user", "forum_answer", "created_at") + read_only_fields = ("id", "user", "created_at") + + def validate(self, attrs): + user = self.context["request"].user + forum_answer = attrs["forum_answer"] + + if ForumAnswerDownvote.objects.filter(user=user, forum_answer=forum_answer).exists(): + raise serializers.ValidationError("You have already downvoted this forum answer.") + + return attrs diff --git a/backend/core/serializers/serializers.py b/backend/core/serializers/serializers.py index e129c4ea..822f3d0e 100644 --- a/backend/core/serializers/serializers.py +++ b/backend/core/serializers/serializers.py @@ -3,7 +3,7 @@ from rest_framework import serializers from ..models import (CustomUser, ForumQuestion, Quiz, QuizQuestion, QuizQuestionChoice, RateQuiz, - Tag, ForumBookmark, ForumAnswer, ForumUpvote, ForumDownvote, TakeQuiz) + Tag, ForumBookmark, ForumAnswer, ForumUpvote, ForumDownvote, TakeQuiz, ForumAnswerUpvote, ForumAnswerDownvote) from .forum_vote_serializer import ForumUpvoteSerializer, ForumDownvoteSerializer from .take_quiz_serializer import TakeQuizSerializer @@ -50,13 +50,17 @@ class Meta: class ForumAnswerSerializer(serializers.ModelSerializer): author = UserInfoSerializer(read_only=True) - is_my_answer = serializers.SerializerMethodField() + upvotes_count = serializers.SerializerMethodField() is_upvoted = serializers.SerializerMethodField() + downvotes_count = serializers.SerializerMethodField() is_downvoted = serializers.SerializerMethodField() + + class Meta: model = ForumAnswer - fields = ('id', 'answer', 'author', 'created_at', 'is_my_answer', 'is_upvoted', 'is_downvoted') - read_only_fields = ('author', 'created_at') + fields = ('id', 'answer', 'author', 'created_at', 'upvotes_count', 'is_upvoted', 'downvotes_count', 'is_downvoted') + read_only_fields = ('author', 'created_at', 'upvotes_count', 'is_upvoted', 'downvotes_count', 'is_downvoted') + def get_is_my_answer(self, obj): user = self.context['request'].user @@ -81,6 +85,32 @@ def get_is_downvoted(self, obj): def create(self, validated_data): return ForumAnswer.objects.create(**validated_data) + def update(self, instance, validated_data): + instance.answer = validated_data.get('answer', instance.answer) + instance.save() + return instance + + def get_upvotes_count(self, obj): + return obj.upvotes.count() + + def get_is_upvoted(self, obj): + user = self.context['request'].user + if not user.is_authenticated: + return None + upvote = ForumAnswerUpvote.objects.filter(user=user, forum_answer=obj).first() + return upvote.id if upvote else None + + def get_downvotes_count(self, obj): + return obj.downvotes.count() + + def get_is_downvoted(self, obj): + user = self.context['request'].user + if not user.is_authenticated: + return False + downvote = ForumAnswerDownvote.objects.filter(user=user, forum_answer=obj).first() + return downvote.id if downvote else None + + class ForumQuestionSerializer(serializers.ModelSerializer): tags = TagSerializer(many=True) # For nested representation of tags author = UserInfoSerializer(read_only=True) diff --git a/backend/core/tests/test_forum_upvote_downvote.py b/backend/core/tests/test_forum_upvote_downvote.py index 6ba1573f..ec5a43b4 100644 --- a/backend/core/tests/test_forum_upvote_downvote.py +++ b/backend/core/tests/test_forum_upvote_downvote.py @@ -3,14 +3,13 @@ from django.urls import reverse from rest_framework_simplejwt.tokens import RefreshToken from faker import Faker -from core.models import ForumUpvote, ForumDownvote, ForumQuestion +from core.models import ForumUpvote, ForumDownvote, ForumQuestion, ForumAnswer, ForumAnswerUpvote, ForumAnswerDownvote from django.contrib.auth import get_user_model User = get_user_model() fake = Faker() - -class ForumUpvoteAPITest(APITestCase): +class ForumSetup(APITestCase): def setUp(self): # Create a test user self.user = User.objects.create_user( @@ -26,24 +25,27 @@ def setUp(self): self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {str(refresh.access_token)}') # Create a ForumQuestion - # self.forum_question = ForumQuestion.objects.create( - # title="Test Forum Question", - # question="This is a test question for votes.", - # author=self.user - # ) - self.forum_question_response = self.client.post(reverse('forum-question-list'), { - "title": "Test Forum Question", - "question": "This is a test question for votes.", - "tags": [ - {"name": "Django", "linked_data_id": "123", "description": "A web framework."}, - {"name": "DRF", "linked_data_id": "456", "description": "Django Rest Framework."} - ] - }, format='json').data - + self.forum_question = ForumQuestion.objects.create( + title="Test Forum Question", + question="This is a test question for votes.", + author=self.user + ) self.forum_question = ForumQuestion.objects.get(title='Test Forum Question') # Vote data - self.data = {"forum_question": self.forum_question.id} + self.data = {"forum_question": self.forum_question.id} + + # Create a ForumAnswer + self.forum_answer = ForumAnswer.objects.create( + forum_question=self.forum_question, + author=self.user, + answer="This is a test answer for votes." + ) + response = self.client.get(reverse('forum-question-answers-list', args=[self.forum_question.id])) + self.forum_answer = response.data['results'][0] + + +class ForumUpvoteAPITest(ForumSetup): def test_create_forum_upvote(self): question_response = self.client.get(reverse("forum-question-detail", args=[self.forum_question.id])) @@ -133,30 +135,7 @@ def test_get_list_upvote_pagination(self): self.assertIn("created_at", response.data['results'][0]) -class ForumDownvoteAPITest(APITestCase): - def setUp(self): - # Create a test user - self.user = User.objects.create_user( - username=fake.user_name(), - password="testpassword", - email=fake.email(), - full_name=fake.name() - ) - - # Authenticate the test client - self.client = APIClient() - refresh = RefreshToken.for_user(self.user) - self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {str(refresh.access_token)}') - - # Create a ForumQuestion - self.forum_question = ForumQuestion.objects.create( - title="Test Forum Question", - question="This is a test question for votes.", - author=self.user - ) - - # Vote data - self.data = {"forum_question": self.forum_question.id} +class ForumDownvoteAPITest(ForumSetup): def test_create_forum_downvote(self): """Test creating a forum downvote""" @@ -235,4 +214,68 @@ def test_get_list_downvote_pagination(self): self.assertIn("user", response.data['results'][0]) self.assertIn("forum_question", response.data['results'][0]) self.assertIn("created_at", response.data['results'][0]) - \ No newline at end of file + + +class ForumAnswerUpvoteAPITest(ForumSetup): + + def test_create_forum_answer_upvote(self): + """Test creating a forum answer upvote""" + upvote_count = self.forum_answer["upvotes_count"] + + # Vote data + data = {"forum_answer": self.forum_answer["id"]} + + # Send POST request to create a new vote + response = self.client.post(reverse('forum-answer-upvote-list'), data=data, format='json') + + # Assertions + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertTrue(ForumAnswerUpvote.objects.filter(user=self.user, forum_answer=self.forum_answer["id"]).exists()) + self.assertIn("id", response.data) + self.assertIn("user", response.data) + self.assertIn("forum_answer", response.data) + response = self.client.get(reverse('forum-question-answers-detail', args=[self.forum_question.id, self.forum_answer["id"]])) + self.assertEqual(response.data["upvotes_count"], upvote_count + 1) + + def test_delete_forum_answer_upvote(self): + """Test deleting a forum answer upvote""" + # Create a forum answer and upvote to delete + forum_answer = ForumAnswer.objects.create( + forum_question=self.forum_question, + author=self.user, + answer="This is a test answer for votes." + ) + response = self.client.post(reverse('forum-answer-upvote-list'), {"forum_answer": forum_answer.id}, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response_question = self.client.get(reverse("forum-question-answers-detail", args=[self.forum_question.id, forum_answer.id])) + upvote_count = response_question.data["upvotes_count"] + # Send DELETE request to remove the upvote + response = self.client.delete(reverse('forum-answer-upvote-detail', args=[response.data["id"]])) + + # Assertions + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + response = self.client.get(reverse("forum-question-answers-detail", args=[self.forum_question.id, forum_answer.id])) + + self.assertEqual(response.data["upvotes_count"], upvote_count - 1) + + def test_downvote_after_upvote(self): + """Test deleting a forum answer upvote""" + # Create a forum answer and upvote to delete + forum_answer = ForumAnswer.objects.create( + forum_question=self.forum_question, + author=self.user, + answer="This is a test answer for votes." + ) + response = self.client.post(reverse('forum-answer-upvote-list'), {"forum_answer": forum_answer.id}, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response_question = self.client.get(reverse("forum-question-answers-detail", args=[self.forum_question.id, forum_answer.id])) + upvote_count = response_question.data["upvotes_count"] + downvote_count = response_question.data["downvotes_count"] + # Send DELETE request to remove the upvote + response = self.client.post(reverse('forum-answer-downvote-list'), {"forum_answer": forum_answer.id}) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response_question = self.client.get(reverse("forum-question-answers-detail", args=[self.forum_question.id, forum_answer.id])) + self.assertEqual(response_question.status_code, status.HTTP_200_OK) + self.assertEqual(response_question.data["upvotes_count"], upvote_count - 1) + self.assertEqual(response_question.data["downvotes_count"], downvote_count + 1) diff --git a/backend/core/urls.py b/backend/core/urls.py index 3a018f80..d96198c4 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -4,8 +4,7 @@ from .views.rate_quiz_views import RateQuizViewSet from .views.take_quiz_views import TakeQuizViewSet from .views.forum_bookmark_views import ForumBookmarkViewSet -from .views.forum_vote_views import ForumUpvoteViewSet -from .views.forum_vote_views import ForumDownvoteViewSet +from .views.forum_vote_views import ForumUpvoteViewSet, ForumDownvoteViewSet, ForumAnswerUpvoteViewSet, ForumAnswerDownvoteViewSet from rest_framework.routers import DefaultRouter from rest_framework_nested import routers from .views import views @@ -37,6 +36,8 @@ router.register(r'forum-bookmarks', ForumBookmarkViewSet, basename='forumbookmark') router.register(r'forum-upvote', ForumUpvoteViewSet, basename='forum-upvote') router.register(r'forum-downvote', ForumDownvoteViewSet, basename='forum-downvote') +router.register(r'forum-answer-upvote', ForumAnswerUpvoteViewSet, basename='forum-answer-upvote') +router.register(r'forum-answer-downvote', ForumAnswerDownvoteViewSet, basename='forum-answer-downvote') forum_question_router = routers.NestedDefaultRouter(router, r'forum-questions', lookup='forum_question') forum_question_router.register(r'answers', ForumAnswerViewSet, basename='forum-question-answers') diff --git a/backend/core/views/forum_vote_views.py b/backend/core/views/forum_vote_views.py index 90572d68..7336c743 100644 --- a/backend/core/views/forum_vote_views.py +++ b/backend/core/views/forum_vote_views.py @@ -1,6 +1,6 @@ from rest_framework import viewsets, permissions -from ..models import ForumUpvote, ForumDownvote -from ..serializers.forum_vote_serializer import ForumUpvoteSerializer, ForumDownvoteSerializer +from ..models import ForumUpvote, ForumDownvote, ForumAnswerUpvote, ForumAnswerDownvote +from ..serializers.forum_vote_serializer import ForumUpvoteSerializer, ForumDownvoteSerializer, ForumAnswerUpvoteSerializer, ForumAnswerDownvoteSerializer class ForumUpvoteViewSet(viewsets.ModelViewSet): @@ -26,4 +26,28 @@ def perform_create(self, serializer): def get_queryset(self): # Allow users to see only their own downvotes return self.queryset.filter(user=self.request.user) - \ No newline at end of file + + +class ForumAnswerUpvoteViewSet(viewsets.ModelViewSet): + queryset = ForumAnswerUpvote.objects.all() + serializer_class = ForumAnswerUpvoteSerializer + permission_classes = [permissions.IsAuthenticated] + + def perform_create(self, serializer): + serializer.save(user=self.request.user) + + def get_queryset(self): + # Allow users to see only their own upvotes + return self.queryset.filter(user=self.request.user) + +class ForumAnswerDownvoteViewSet(viewsets.ModelViewSet): + queryset = ForumAnswerDownvote.objects.all() + serializer_class = ForumAnswerDownvoteSerializer + permission_classes = [permissions.IsAuthenticated] + + def perform_create(self, serializer): + serializer.save(user=self.request.user) + + def get_queryset(self): + # Allow users to see only their own downvotes + return self.queryset.filter(user=self.request.user) \ No newline at end of file