diff --git a/api/auth.go b/api/auth.go index c2d09d6..9b2b6a2 100644 --- a/api/auth.go +++ b/api/auth.go @@ -4,7 +4,6 @@ import ( "auxstream/db" "fmt" "net/http" - "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" @@ -27,17 +26,26 @@ func CookieAuthMiddleware(c *gin.Context) { c.Next() } +type AuthForm struct { + Username string `form:"username" binding:"required"` + Password string `form:"password" binding:"required"` +} + func Signup(c *gin.Context) { - username := c.PostForm("username") - password := c.PostForm("password") + var form AuthForm + + if err := c.ShouldBind(&form); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } - pHash, err := hashPassword(password) + pHash, err := hashPassword(form.Password) if err != nil { fmt.Println("password hash failure: ", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to signup user"}) return } - err = db.DAO.CreateUser(c, username, pHash) + err = db.DAO.CreateUser(c, form.Username, pHash) if err != nil { fmt.Println("CreateUser: ", err.Error()) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to signup user"}) @@ -49,22 +57,21 @@ func Signup(c *gin.Context) { func Login(c *gin.Context) { session := sessions.Default(c) - username := c.PostForm("username") - password := c.PostForm("password") - // Validate form input - if strings.Trim(username, "") == " " || strings.Trim(password, " ") == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "parameters can't be empty"}) + var form AuthForm + + if err := c.ShouldBind(&form); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - user, err := db.DAO.GetUserWithUsername(c, username) + user, err := db.DAO.GetUserWithUsername(c, form.Username) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "user not found"}) return } - if !cmpHashString(user.PasswordHash, password) { + if !cmpHashString(user.PasswordHash, form.Password) { c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid username or password"}) return } diff --git a/api/track.go b/api/track.go index e5e6325..2385bd7 100644 --- a/api/track.go +++ b/api/track.go @@ -7,7 +7,6 @@ import ( "github.com/gin-gonic/gin" "mime/multipart" "net/http" - "strconv" ) // FetchTracksByArtistHandler fetch tracks by artist (limit results < 100) @@ -25,15 +24,25 @@ func FetchTracksByArtistHandler(c *gin.Context) { } +type FetchTrackQueryParams struct { + PageSize int8 `form:"pagesize" binding:"gte=0"` + PageNum int8 `form:"pagenumber" binding:"gte=1"` +} + // FetchTracksHandler fetch paginated tracks with limit on page size func FetchTracksHandler(c *gin.Context) { - pagesize := c.Query("pagesize") - pagenumber := c.Query("pagenumber") + var reqParams FetchTrackQueryParams - limit, _ := strconv.Atoi(pagesize) - offset, _ := strconv.Atoi(pagenumber) + fmt.Printf("pagesize: %s\npagenum: %s\n ", c.Query("pagesize"), c.Query("pagenumber")) + if err := c.ShouldBindQuery(&reqParams); err != nil { + c.JSON(http.StatusBadRequest, errorResponse(err.Error())) + return + } - tracks, err := db.DAO.GetTracks(c, int32(limit), int32((offset-1)*limit)) + limit := reqParams.PageSize + offset := (reqParams.PageNum - 1) * reqParams.PageSize + + tracks, err := db.DAO.GetTracks(c, limit, offset) if err != nil { c.JSON(http.StatusInternalServerError, errorResponse(err.Error())) return @@ -45,16 +54,24 @@ func FetchTracksHandler(c *gin.Context) { } +type AddTrackForm struct { + Title string `form:"title" binding:"required"` + ArtistId int `form:"artist_id" binding:"required"` + Audio *multipart.FileHeader `form:"audio" binding:"required"` +} + // AddTrackHandler add track to the system func AddTrackHandler(c *gin.Context) { - form, _ := c.MultipartForm() - trackTittle := form.Value["title"][0] - trackArtistId, err := strconv.Atoi(form.Value["artist_id"][0]) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, errorResponse("artistId should be a valid number")) + var reqForm AddTrackForm + if err := c.ShouldBind(&reqForm); err != nil { + c.JSON(http.StatusBadRequest, errorResponse(err.Error())) return } - file := form.File["audio"][0] + + trackTittle := reqForm.Title + trackArtistId := reqForm.ArtistId + file := reqForm.Audio + if file.Size <= 0 { c.AbortWithStatusJSON(http.StatusBadRequest, errorResponse("audio for track not found")) return @@ -81,18 +98,22 @@ func AddTrackHandler(c *gin.Context) { }) } +type BulkTrackUploadForm struct { + Titles []string `form:"track_titles" binding:"required"` + Files []*multipart.FileHeader `form:"track_files" binding:"required"` + ArtistId int `form:"artist_id" binding:"required"` +} + // BulkTrackUploadHandler enables bulk track uploads func BulkTrackUploadHandler(c *gin.Context) { - form, _ := c.MultipartForm() - titles := form.Value["track_title"] - files := form.File["track_files"] - artistId, err := strconv.Atoi(form.Value["artist_id"][0]) + var reqForm BulkTrackUploadForm - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, errorResponse(fmt.Sprintf("invalid artist_id value: %s", err.Error()))) + if err := c.ShouldBind(&reqForm); err != nil { + c.JSON(http.StatusBadRequest, errorResponse(err.Error())) + return } - fileNames, err := processFiles(files) + fileNames, err := processFiles(reqForm.Files) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, errorResponse(fmt.Sprintf("audio upload failed: %s", err.Error()))) @@ -104,11 +125,11 @@ func BulkTrackUploadHandler(c *gin.Context) { // filter tracks that failed to upload for idx, fileName := range fileNames { if fileName != "" { - trackTitles = append(trackTitles, titles[idx]) + trackTitles = append(trackTitles, reqForm.Titles[idx]) filteredFileNames = append(filteredFileNames, fileName) } } - rows, err := db.DAO.BulkCreateTracks(c, trackTitles, artistId, filteredFileNames) + rows, err := db.DAO.BulkCreateTracks(c, trackTitles, reqForm.ArtistId, filteredFileNames) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, errorResponse(fmt.Sprintf("audio upload failed: %s", err.Error()))) diff --git a/app.env b/app.env index db6e292..7b3afaa 100644 --- a/app.env +++ b/app.env @@ -1,5 +1,5 @@ DATABASE_URL=postgresql://postgres:secret@127.0.0.1:5432/auxstreamdb?sslmode=disable SESSION_STRING=test_random_string -GIN_MODE=release +GIN_MODE=debug PORT=5009 Addr=127.0.0.1 \ No newline at end of file diff --git a/db/db.go b/db/db.go index 89dd44a..bc5d96b 100644 --- a/db/db.go +++ b/db/db.go @@ -111,7 +111,7 @@ func (dao *DataBaseAccessObject) SearchTrackByArtist(ctx context.Context, return GetTrackByArtist(ctx, artist) } -func (dao *DataBaseAccessObject) GetTracks(ctx context.Context, limit int32, offset int32) (tracks []*Track, err error) { +func (dao *DataBaseAccessObject) GetTracks(ctx context.Context, limit int8, offset int8) (tracks []*Track, err error) { return GetTracks(ctx, limit, offset) } diff --git a/db/models.go b/db/models.go index 74af365..f1957fd 100644 --- a/db/models.go +++ b/db/models.go @@ -2,6 +2,7 @@ package db import ( "context" + "fmt" "github.com/jackc/pgx/v5" "time" ) @@ -61,7 +62,8 @@ func (artist *Artist) Commit(ctx context.Context) (err error) { return } -func GetTracks(ctx context.Context, limit int32, offset int32) (tracks []*Track, err error) { +func GetTracks(ctx context.Context, limit int8, offset int8) (tracks []*Track, err error) { + fmt.Printf("GetTracks: limit: %d, offset: %d\n", limit, offset) tracks = []*Track{} stmt := `SELECT id, title, artist_id, file, created_at FROM auxstream.tracks diff --git a/tests/api_test.go b/tests/api_test.go index d71fa8f..f9789e2 100644 --- a/tests/api_test.go +++ b/tests/api_test.go @@ -135,7 +135,7 @@ func TestHTTPTrackUploadBatch(t *testing.T) { formData := url.Values{} formData.Add("artist_id", strconv.Itoa(artistId)) for i := 0; i < testRecordCnt; i++ { - formData.Add("track_title", fmt.Sprintf("#%d", i)) + formData.Add("track_titles", fmt.Sprintf("#%d", i)) } post, err := req.Post(tserver.URL+"/upload_batch_track", formData, trackFiles) @@ -155,7 +155,7 @@ func TestHTTPFetchTracks(t *testing.T) { mockConn.ExpectQuery(` SELECT id, title, artist_id, file, created_at `). - WithArgs(int32(2), int32(0)). + WithArgs(int8(2), int8(0)). WillReturnRows(pgxmock.NewRows(columns). AddRow(1, "Title", 1, "Test file", time.Now()). AddRow(1, "Title", 1, "Test file", time.Now()).