Skip to content

Commit

Permalink
fix: added SplitPattern function to prevents bugs where the path para…
Browse files Browse the repository at this point in the history
…meter contains the HTTP method
  • Loading branch information
ralvarezdev committed Feb 1, 2025
1 parent b050b0f commit 6f055aa
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 57 deletions.
23 changes: 4 additions & 19 deletions http/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
type (
// Module is the struct for the route module
Module struct {
Path string
Pattern string
Service interface{}
Controller interface{}
BeforeLoadFn func(m *Module)
Expand Down Expand Up @@ -47,17 +47,17 @@ func (m *Module) Create(

// Set the base route
if m.Middlewares != nil {
m.RouterWrapper = baseRouter.NewGroup(m.Path, *m.Middlewares...)
m.RouterWrapper = baseRouter.NewGroup(m.Pattern, *m.Middlewares...)
} else {
m.RouterWrapper = baseRouter.NewGroup(m.Path)
m.RouterWrapper = baseRouter.NewGroup(m.Pattern)
}

// Create the submodules controllers router
router := m.GetRouter()
if m.Submodules != nil {
for i, submodule := range *m.Submodules {
if submodule == nil {
return fmt.Errorf(ErrNilSubmodule, m.Path, i)
return fmt.Errorf(ErrNilSubmodule, m.Pattern, i)
}

if err := submodule.Create(router); err != nil {
Expand All @@ -82,18 +82,3 @@ func (m *Module) Create(
func (m *Module) GetRouter() gonethttproute.RouterWrapper {
return m.RouterWrapper
}

// GetPath is a function that returns the path
func (m *Module) GetPath() string {
return m.Path
}

// GetService is a function that returns the service
func (m *Module) GetService() interface{} {
return m.Service
}

// GetController is a function that returns the controller
func (m *Module) GetController() interface{} {
return m.Controller
}
1 change: 1 addition & 0 deletions http/route/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import (
var (
ErrNilRouter = errors.New("router cannot be nil")
ErrNilMiddleware = "%s: middleware at index %d cannot be nil"
ErrEmptyPattern = errors.New("pattern cannot be empty")
)
12 changes: 6 additions & 6 deletions http/route/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ func NewLogger(header string, modeLogger gologgermode.Logger) (*Logger, error) {
}

// RegisterRouteGroup registers a route group
func (l *Logger) RegisterRouteGroup(routerPath string, routerGroupPath string) {
func (l *Logger) RegisterRouteGroup(fullPath, pattern string) {
l.logger.Debug(
"registering route group",
"router path: "+routerPath,
"router group path: "+routerGroupPath,
"router path: "+fullPath,
"router group pattern: "+pattern,
)
}

// RegisterRoute registers a route
func (l *Logger) RegisterRoute(routerPath string, routePath string) {
func (l *Logger) RegisterRoute(fullPath, pattern string) {
l.logger.Debug(
"registering route",
"router path: "+routerPath,
"route path: "+routePath,
"router path: "+fullPath,
"route pattern: "+pattern,
)
}
111 changes: 79 additions & 32 deletions http/route/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
goflagsmode "github.com/ralvarezdev/go-flags/mode"
"net/http"
"strings"
)

type (
Expand Down Expand Up @@ -38,41 +39,67 @@ type (
middlewares ...func(next http.Handler) http.Handler,
) *Router
RegisterGroup(router *Router)
Pattern() string
RelativePath() string
FullPath() string
Method() string
}

// Router is the route group struct
Router struct {
middlewares []func(http.Handler) http.Handler
firstHandler http.Handler
mux *http.ServeMux
pattern string
relativePath string
fullPath string
method string
mode *goflagsmode.Flag
logger *Logger
}
)

// AddSlash adds a slash to the path
func AddSlash(path string) string {
if path == "" {
return "/"
} else if path[0] != '/' {
return "/" + path
// SplitPattern returns the method and the path from the pattern
func SplitPattern(pattern string) (string, string, error) {
// Trim the pattern
strings.Trim(pattern, " ")

// Check if the pattern is empty
if pattern == "" {
return "", "", ErrEmptyPattern
}
return path

// Split the pattern by the first space
spaceIndex := 0
for i, char := range pattern {
if char == ' ' {
spaceIndex = i
break
}
}

// Get the method and the path
method := pattern[:spaceIndex]
path := pattern[spaceIndex+1:]

// Trim the path
strings.Trim(path, " ")

return method, path, nil
}

// NewRouter creates a new router
func NewRouter(
path string,
pattern string,
mode *goflagsmode.Flag,
logger *Logger,
middlewares ...func(next http.Handler) http.Handler,
) (*Router, error) {
// Add a slash to the path if it does not have it
path = AddSlash(path)
// Split the method and path from the pattern
method, path, err := SplitPattern(pattern)
if err != nil {
return nil, err
}

// Initialize the multiplexer
mux := http.NewServeMux()
Expand All @@ -91,8 +118,10 @@ func NewRouter(
middlewares,
firstHandler,
mux,
pattern,
path,
path,
method,
mode,
logger,
}, nil
Expand All @@ -110,16 +139,19 @@ func NewBaseRouter(
// NewGroup creates a new router group
func NewGroup(
baseRouter *Router,
relativePath string,
pattern string,
middlewares ...func(next http.Handler) http.Handler,
) (*Router, error) {
// Check if the base router is nil
if baseRouter == nil {
return nil, ErrNilRouter
}

// Add a slash to the path if it does not have it
relativePath = AddSlash(relativePath)
// Split the method and path from the pattern
method, relativePath, err := SplitPattern(pattern)
if err != nil {
return nil, err
}

// Check the base router path
var fullPath string
Expand Down Expand Up @@ -148,8 +180,10 @@ func NewGroup(
firstHandler: firstHandler,
mux: mux,
logger: baseRouter.logger,
pattern: pattern,
relativePath: relativePath,
fullPath: fullPath,
method: method,
mode: baseRouter.mode,
}

Expand All @@ -176,43 +210,46 @@ func (r *Router) GetMiddlewares() *[]func(http.Handler) http.Handler {

// HandleFunc registers a new route with a path, the handler function and the middlewares
func (r *Router) HandleFunc(
relativePath string,
pattern string,
handler http.HandlerFunc,
middlewares ...func(http.Handler) http.Handler,
) {
// Chain the handlers
firstHandler := ChainHandlers(handler, middlewares...)

// Register the route
r.mux.HandleFunc(relativePath, firstHandler.ServeHTTP)
r.mux.HandleFunc(pattern, firstHandler.ServeHTTP)

if r.logger != nil && r.mode != nil && !r.mode.IsProd() {
r.logger.RegisterRoute(r.relativePath, relativePath)
r.logger.RegisterRoute(r.relativePath, pattern)
}
}

// ExactHandleFunc registers a new route with a path, the handler function and the middlewares
func (r *Router) ExactHandleFunc(
relativePath string,
pattern string,
handler http.HandlerFunc,
middlewares ...func(http.Handler) http.Handler,
) {
// Add slash to the path
relativePath = AddSlash(relativePath)
// Split the method and path from the pattern
method, path, err := SplitPattern(pattern)
if err != nil {
panic(err)
}

// Chain the handlers
firstHandler := ChainHandlers(handler, middlewares...)

// Add the '$' wildcard to the end of the path to match the exact path
if relativePath[len(relativePath)-1] == '/' {
relativePath += "{$}"
if path[len(path)-1] == '/' {
path += "{$}"
}

// Register the route
r.mux.HandleFunc(relativePath, firstHandler.ServeHTTP)
r.mux.HandleFunc(method+" "+path, firstHandler.ServeHTTP)

if r.logger != nil && r.mode != nil && !r.mode.IsProd() {
r.logger.RegisterRoute(r.relativePath, relativePath)
r.logger.RegisterRoute(r.relativePath, pattern)
}
}

Expand All @@ -237,28 +274,28 @@ func (r *Router) RegisterExactRoute(
}

// RegisterHandler registers a new route group with a path and a handler function
func (r *Router) RegisterHandler(relativePath string, handler http.Handler) {
// Check if the path contains a trailing slash and remove it
if len(relativePath) > 1 && relativePath[len(relativePath)-1] == '/' {
relativePath = relativePath[:len(relativePath)-1]
func (r *Router) RegisterHandler(pattern string, handler http.Handler) {
// Check if the pattern contains a trailing slash and remove it
if len(pattern) > 0 && pattern[len(pattern)-1] == '/' {
pattern = pattern[:len(pattern)-1]
}

// Register the route group
r.mux.Handle(relativePath+"/", http.StripPrefix(relativePath, handler))
r.mux.Handle(pattern+"/", http.StripPrefix(pattern, handler))

if r.logger != nil && r.mode != nil && !r.mode.IsProd() {
r.logger.RegisterRouteGroup(r.relativePath, relativePath)
r.logger.RegisterRouteGroup(r.fullPath, pattern)
}
}

// RegisterGroup registers a new router group with a path and a router
func (r *Router) RegisterGroup(router *Router) {
r.RegisterHandler(router.RelativePath(), router.mux)
r.RegisterHandler(router.Pattern(), router.mux)
}

// NewGroup creates a new router group with a path
func (r *Router) NewGroup(
relativePath string,
pattern string,
middlewares ...func(next http.Handler) http.Handler,
) *Router {
// Create the middlewares slice
Expand All @@ -271,10 +308,15 @@ func (r *Router) NewGroup(
fullMiddlewares = append(fullMiddlewares, middlewares...)

// Create a new group
newGroup, _ := NewGroup(r, relativePath, fullMiddlewares...)
newGroup, _ := NewGroup(r, pattern, fullMiddlewares...)
return newGroup
}

// Pattern returns the pattern
func (r *Router) Pattern() string {
return r.pattern
}

// RelativePath returns the relative path
func (r *Router) RelativePath() string {
return r.relativePath
Expand All @@ -284,3 +326,8 @@ func (r *Router) RelativePath() string {
func (r *Router) FullPath() string {
return r.fullPath
}

// Method returns the method
func (r *Router) Method() string {
return r.method
}

0 comments on commit 6f055aa

Please sign in to comment.