Skip to content

Commit

Permalink
[NOD-322] Minor api-server refactoring. (#401)
Browse files Browse the repository at this point in the history
* [NOD-322] Make database.DB a function

* [NOD-322] Move context to be the first parameter in all functions

* [NOD-322] Set db to nil on database.Close()

* [NOD-322] Tidy go.mod/go.sum

* [NOD-322] Use http package const + message for StatusInternalServerError
  • Loading branch information
svarogg authored and stasatdaglabs committed Sep 15, 2019
1 parent 369031f commit 502b510
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 35 deletions.
14 changes: 11 additions & 3 deletions apiserver/controllers/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@ package controllers
import (
"encoding/hex"
"fmt"
"net/http"

"github.com/daglabs/btcd/apiserver/database"
"github.com/daglabs/btcd/apiserver/models"
"github.com/daglabs/btcd/apiserver/utils"
"github.com/daglabs/btcd/util/daghash"
"net/http"
)

// GetBlockByHashHandler returns a block by a given hash.
func GetBlockByHashHandler(blockHash string) (interface{}, *utils.HandlerError) {
if bytes, err := hex.DecodeString(blockHash); err != nil || len(bytes) != daghash.HashSize {
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity, fmt.Sprintf("The given block hash is not a hex-encoded %d-byte hash.", daghash.HashSize))
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity,
fmt.Sprintf("The given block hash is not a hex-encoded %d-byte hash.", daghash.HashSize))
}

db, err := database.DB()
if err != nil {
return nil, utils.NewHandlerError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

block := &models.Block{}
database.DB.Where(&models.Block{BlockHash: blockHash}).Preload("AcceptingBlock").First(block)
db.Where(&models.Block{BlockHash: blockHash}).Preload("AcceptingBlock").First(block)
if block.ID == 0 {
return nil, utils.NewHandlerError(http.StatusNotFound, "No block with the given block hash was found.")
}
Expand Down
45 changes: 36 additions & 9 deletions apiserver/controllers/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,31 @@ package controllers
import (
"encoding/hex"
"fmt"
"net/http"

"github.com/daglabs/btcd/apiserver/database"
"github.com/daglabs/btcd/apiserver/models"
"github.com/daglabs/btcd/apiserver/utils"
"github.com/daglabs/btcd/util/daghash"
"github.com/jinzhu/gorm"
"net/http"
)

const maximumGetTransactionsLimit = 1000

// GetTransactionByIDHandler returns a transaction by a given transaction ID.
func GetTransactionByIDHandler(txID string) (interface{}, *utils.HandlerError) {
if bytes, err := hex.DecodeString(txID); err != nil || len(bytes) != daghash.TxIDSize {
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity, fmt.Sprintf("The given txid is not a hex-encoded %d-byte hash.", daghash.TxIDSize))
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity,
fmt.Sprintf("The given txid is not a hex-encoded %d-byte hash.", daghash.TxIDSize))
}

db, err := database.DB()
if err != nil {
return nil, utils.NewHandlerError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

tx := &models.Transaction{}
query := database.DB.Where(&models.Transaction{TransactionID: txID})
query := db.Where(&models.Transaction{TransactionID: txID})
addTxPreloadedFields(query).First(&tx)
if tx.ID == 0 {
return nil, utils.NewHandlerError(http.StatusNotFound, "No transaction with the given txid was found.")
Expand All @@ -30,11 +38,17 @@ func GetTransactionByIDHandler(txID string) (interface{}, *utils.HandlerError) {
// GetTransactionByHashHandler returns a transaction by a given transaction hash.
func GetTransactionByHashHandler(txHash string) (interface{}, *utils.HandlerError) {
if bytes, err := hex.DecodeString(txHash); err != nil || len(bytes) != daghash.HashSize {
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity, fmt.Sprintf("The given txhash is not a hex-encoded %d-byte hash.", daghash.HashSize))
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity,
fmt.Sprintf("The given txhash is not a hex-encoded %d-byte hash.", daghash.HashSize))
}

db, err := database.DB()
if err != nil {
return nil, utils.NewHandlerError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

tx := &models.Transaction{}
query := database.DB.
Where(&models.Transaction{TransactionHash: txHash})
query := db.Where(&models.Transaction{TransactionHash: txHash})
addTxPreloadedFields(query).First(&tx)
if tx.ID == 0 {
return nil, utils.NewHandlerError(http.StatusNotFound, "No transaction with the given txhash was found.")
Expand All @@ -46,10 +60,17 @@ func GetTransactionByHashHandler(txHash string) (interface{}, *utils.HandlerErro
// where the given address is either an input or an output.
func GetTransactionsByAddressHandler(address string, skip uint64, limit uint64) (interface{}, *utils.HandlerError) {
if limit > maximumGetTransactionsLimit {
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity, fmt.Sprintf("The maximum allowed value for the limit is %d", maximumGetTransactionsLimit))
return nil, utils.NewHandlerError(http.StatusUnprocessableEntity,
fmt.Sprintf("The maximum allowed value for the limit is %d", maximumGetTransactionsLimit))
}

db, err := database.DB()
if err != nil {
return nil, utils.NewHandlerError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

txs := []*models.Transaction{}
query := database.DB.
query := db.
Joins("LEFT JOIN `transaction_outputs` ON `transaction_outputs`.`transaction_id` = `transactions`.`id`").
Joins("LEFT JOIN `addresses` AS `out_addresses` ON `out_addresses`.`id` = `transaction_outputs`.`address_id`").
Joins("LEFT JOIN `transaction_inputs` ON `transaction_inputs`.`transaction_id` = `transactions`.`id`").
Expand All @@ -70,12 +91,18 @@ func GetTransactionsByAddressHandler(address string, skip uint64, limit uint64)

// GetUTXOsByAddressHandler searches for all UTXOs that belong to a certain address.
func GetUTXOsByAddressHandler(address string) (interface{}, *utils.HandlerError) {
db, err := database.DB()
if err != nil {
return nil, utils.NewHandlerError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
}

var transactionOutputs []*models.TransactionOutput
database.DB.
db.
Joins("LEFT JOIN `addresses` ON `addresses`.`id` = `transaction_outputs`.`address_id`").
Where("`addresses`.`address` = ? AND `transaction_outputs`.`is_spent` = 0", address).
Preload("Transaction.AcceptingBlock").
Find(&transactionOutputs)

UTXOsResponses := make([]*transactionOutputResponse, len(transactionOutputs))
for i, transactionOutput := range transactionOutputs {
UTXOsResponses[i] = &transactionOutputResponse{
Expand Down
28 changes: 24 additions & 4 deletions apiserver/database/database.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
package database

import (
"errors"
"fmt"
"os"

"github.com/daglabs/btcd/apiserver/config"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jinzhu/gorm"
"os"

"github.com/golang-migrate/migrate/v4"
)

// DB is the API server database.
var DB *gorm.DB
// db is the API server database.
var db *gorm.DB

// DB returns a reference to the database connection
func DB() (*gorm.DB, error) {
if db == nil {
return nil, errors.New("Database is not connected")
}
return db, nil
}

// Connect connects to the database mentioned in
// config variable.
Expand All @@ -26,13 +36,23 @@ func Connect(cfg *config.Config) error {
" the database and start again.")
}

DB, err = gorm.Open("mysql", connectionString)
db, err = gorm.Open("mysql", connectionString)
if err != nil {
return err
}
return nil
}

// Close closes the connection to the database
func Close() error {
if db == nil {
return nil
}
err := db.Close()
db = nil
return err
}

func buildConnectionString(cfg *config.Config) string {
return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=True",
cfg.DBUser, cfg.DBPassword, cfg.DBAddress, cfg.DBName)
Expand Down
3 changes: 2 additions & 1 deletion apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"

"github.com/daglabs/btcd/apiserver/config"
"github.com/daglabs/btcd/apiserver/database"
"github.com/daglabs/btcd/apiserver/jsonrpc"
Expand All @@ -27,7 +28,7 @@ func main() {
panic(fmt.Errorf("Error connecting to database: %s", err))
}
defer func() {
err := database.DB.Close()
err := database.Close()
if err != nil {
panic(fmt.Errorf("Error closing the database: %s", err))
}
Expand Down
26 changes: 15 additions & 11 deletions apiserver/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package server
import (
"encoding/json"
"fmt"
"net/http"
"strconv"

"github.com/daglabs/btcd/apiserver/controllers"
"github.com/daglabs/btcd/apiserver/utils"
"github.com/gorilla/mux"
"net/http"
"strconv"
)

const (
Expand All @@ -24,10 +25,13 @@ const (

const defaultGetTransactionsLimit = 100

func makeHandler(handler func(routeParams map[string]string, queryParams map[string][]string, ctx *utils.APIServerContext) (interface{}, *utils.HandlerError)) func(http.ResponseWriter, *http.Request) {
func makeHandler(
handler func(ctx *utils.APIServerContext, routeParams map[string]string, queryParams map[string][]string) (
interface{}, *utils.HandlerError)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {

ctx := utils.ToAPIServerContext(r.Context())
response, hErr := handler(mux.Vars(r), r.URL.Query(), ctx)
response, hErr := handler(ctx, mux.Vars(r), r.URL.Query())
if hErr != nil {
sendErr(ctx, w, hErr)
return
Expand All @@ -54,7 +58,7 @@ func sendJSONResponse(w http.ResponseWriter, response interface{}) {
}
}

func mainHandler(_ map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func mainHandler(_ *utils.APIServerContext, _ map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return "API server is running", nil
}

Expand Down Expand Up @@ -92,15 +96,15 @@ func addRoutes(router *mux.Router) {
Methods("GET")
}

func getTransactionByIDHandler(routeParams map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getTransactionByIDHandler(_ *utils.APIServerContext, routeParams map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return controllers.GetTransactionByIDHandler(routeParams[routeParamTxID])
}

func getTransactionByHashHandler(routeParams map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getTransactionByHashHandler(_ *utils.APIServerContext, routeParams map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return controllers.GetTransactionByHashHandler(routeParams[routeParamTxHash])
}

func getTransactionsByAddressHandler(routeParams map[string]string, queryParams map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getTransactionsByAddressHandler(_ *utils.APIServerContext, routeParams map[string]string, queryParams map[string][]string) (interface{}, *utils.HandlerError) {
skip := 0
limit := defaultGetTransactionsLimit
if len(queryParams[queryParamSkip]) > 1 {
Expand Down Expand Up @@ -128,14 +132,14 @@ func getTransactionsByAddressHandler(routeParams map[string]string, queryParams
return controllers.GetTransactionsByAddressHandler(routeParams[routeParamAddress], uint64(skip), uint64(limit))
}

func getUTXOsByAddressHandler(routeParams map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getUTXOsByAddressHandler(_ *utils.APIServerContext, routeParams map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return controllers.GetUTXOsByAddressHandler(routeParams[routeParamAddress])
}

func getBlockByHashHandler(routeParams map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getBlockByHashHandler(_ *utils.APIServerContext, routeParams map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return controllers.GetBlockByHashHandler(routeParams[routeParamBlockHash])
}

func getFeeEstimatesHandler(_ map[string]string, _ map[string][]string, _ *utils.APIServerContext) (interface{}, *utils.HandlerError) {
func getFeeEstimatesHandler(_ *utils.APIServerContext, _ map[string]string, _ map[string][]string) (interface{}, *utils.HandlerError) {
return controllers.GetFeeEstimatesHandler()
}
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.12
require (
bou.ke/monkey v1.0.1
github.com/aead/siphash v1.0.1
github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd
github.com/btcsuite/goleveldb v1.0.0
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792
Expand Down
Loading

0 comments on commit 502b510

Please sign in to comment.