The main middleware file. Contains a lot
Signed-off-by: Ethan Wellenreiter <ewellenreiter@gmail.com>
This commit is contained in:
parent
a2837b6d82
commit
65b88d08ef
197
backend/cmd/api/middleware.go
Normal file
197
backend/cmd/api/middleware.go
Normal file
@ -0,0 +1,197 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const sessionCtx string = "session_id"
|
||||
|
||||
func (app *application) AuthSessionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("Session")
|
||||
if err != nil {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Session cookie is missing"))
|
||||
return
|
||||
}
|
||||
// parse cookie
|
||||
token := cookie.Value
|
||||
if token == "" {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Empty session token"))
|
||||
return
|
||||
}
|
||||
|
||||
valid, userID, err := app.auth.Sessions.CheckSession(r.Context(), token) // should have a different function for this
|
||||
if !valid {
|
||||
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token"))
|
||||
return
|
||||
}
|
||||
// userID comes from the session check
|
||||
|
||||
// need to select the user from the token validation stuff
|
||||
ctx := r.Context()
|
||||
|
||||
user, err := app.getUser(ctx, userID)
|
||||
if err != nil {
|
||||
app.unauthorizedErrorResponse(w, r, err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, userCtx, user)
|
||||
|
||||
ctx = context.WithValue(ctx, sessionCtx, token)
|
||||
|
||||
// make sure to add user and role into the context here
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) CSRFCheckMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) Paginate(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// add the page and size to the context
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
// Also include the stuff like getting the receipt and stuff
|
||||
|
||||
func (app *application) checkRole(next http.Handler, roleName string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
user := getUserFromContext(r)
|
||||
|
||||
allowed, err := app.checkRolePrecedence(r.Context(), user, roleName)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// essentially a role checking middleware factory
|
||||
func (app *application) CheckRoleMiddleware(roleName string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return app.checkRole(next, roleName)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *application) CheckUserMatchingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
urlUserID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64)
|
||||
if err != nil {
|
||||
app.badRequestResponse(w, r, fmt.Errorf("Invalid url user ID - Not an integer"))
|
||||
return
|
||||
}
|
||||
|
||||
user := getUserFromContext(r)
|
||||
if int64(urlUserID) != user.ID {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if app.config.rateLimiter.Enabled {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if allow, retryAfter := app.rateLimiter.Allow(ip); !allow {
|
||||
app.rateLimitExceededResponse(w, r, retryAfter.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (app *application) receiptsContextMiddleware(next http.Handler) http.Handler {
|
||||
// add the receipt id to the context? or the receipt class to the context
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
receiptID, err := strconv.ParseInt(chi.URLParam(r, "receiptID"), 10, 64)
|
||||
if err != nil {
|
||||
app.badRequestResponse(w, r, fmt.Errorf("Invalid url receipt ID - Not an integer"))
|
||||
return
|
||||
}
|
||||
|
||||
// need to select the receipt
|
||||
ctx := r.Context()
|
||||
|
||||
receipt, err := app.getReceipt(ctx, receiptID)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if receipt.OwnerID != getUserFromContext(r).ID {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, receiptCtx, receipt)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (app *application) checkRolePrecedence(ctx context.Context, user *auth_storage.User, roleName string) (bool, error) {
|
||||
role, err := app.auth.Roles.GetByName(ctx, roleName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return user.Role.Level >= role.Level, nil
|
||||
}
|
||||
|
||||
func (app *application) checkReceiptOwnership(requiredRole string, next http.HandlerFunc) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := getUserFromContext(r)
|
||||
receipt := getReceiptFromContext(r)
|
||||
|
||||
if receipt.OwnerID == user.ID {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := app.checkRolePrecedence(r.Context(), user, requiredRole)
|
||||
if err != nil {
|
||||
app.internalServerError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
app.forbiddenResponse(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user