198 lines
4.9 KiB
Go
198 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
|
|
auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
|
|
"github.com/go-chi/chi/v5"
|
|
)
|
|
|
|
// sessionKey is an unexported context-key type. Using it instead of the
// built-in string type means the value stored under sessionCtx cannot collide
// with a context value that some other package stores under the plain string
// "session_id" (see the context.WithValue docs; staticcheck SA1029).
type sessionKey string

// sessionCtx is the request-context key under which AuthSessionMiddleware
// stores the raw token taken from the "Session" cookie.
const sessionCtx sessionKey = "session_id"
|
|
|
|
func (app *application) AuthSessionMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie("Session")
|
|
if err != nil {
|
|
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Session cookie is missing"))
|
|
return
|
|
}
|
|
// parse cookie
|
|
token := cookie.Value
|
|
if token == "" {
|
|
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Empty session token"))
|
|
return
|
|
}
|
|
|
|
valid, userID, err := app.auth.Sessions.CheckSession(r.Context(), token) // should have a different function for this
|
|
if !valid {
|
|
app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token"))
|
|
return
|
|
}
|
|
// userID comes from the session check
|
|
|
|
// need to select the user from the token validation stuff
|
|
ctx := r.Context()
|
|
|
|
user, err := app.getUser(ctx, userID)
|
|
if err != nil {
|
|
app.unauthorizedErrorResponse(w, r, err)
|
|
return
|
|
}
|
|
ctx = context.WithValue(ctx, userCtx, user)
|
|
|
|
ctx = context.WithValue(ctx, sessionCtx, token)
|
|
|
|
// make sure to add user and role into the context here
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func (app *application) CSRFCheckMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
})
|
|
}
|
|
|
|
func (app *application) Paginate(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// add the page and size to the context
|
|
|
|
})
|
|
}
|
|
|
|
// TODO: also add middleware for related lookups, such as loading the receipt into the request context.
|
|
|
|
func (app *application) checkRole(next http.Handler, roleName string) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
user := getUserFromContext(r)
|
|
|
|
allowed, err := app.checkRolePrecedence(r.Context(), user, roleName)
|
|
if err != nil {
|
|
app.internalServerError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if !allowed {
|
|
app.forbiddenResponse(w, r)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// essentially a role checking middleware factory
|
|
func (app *application) CheckRoleMiddleware(roleName string) func(next http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return app.checkRole(next, roleName)
|
|
}
|
|
}
|
|
|
|
func (app *application) CheckUserMatchingMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
urlUserID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64)
|
|
if err != nil {
|
|
app.badRequestResponse(w, r, fmt.Errorf("Invalid url user ID - Not an integer"))
|
|
return
|
|
}
|
|
|
|
user := getUserFromContext(r)
|
|
if int64(urlUserID) != user.ID {
|
|
app.forbiddenResponse(w, r)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if app.config.rateLimiter.Enabled {
|
|
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err != nil {
|
|
app.internalServerError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if allow, retryAfter := app.rateLimiter.Allow(ip); !allow {
|
|
app.rateLimitExceededResponse(w, r, retryAfter.String())
|
|
return
|
|
}
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (app *application) receiptsContextMiddleware(next http.Handler) http.Handler {
|
|
// add the receipt id to the context? or the receipt class to the context
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
receiptID, err := strconv.ParseInt(chi.URLParam(r, "receiptID"), 10, 64)
|
|
if err != nil {
|
|
app.badRequestResponse(w, r, fmt.Errorf("Invalid url receipt ID - Not an integer"))
|
|
return
|
|
}
|
|
|
|
// need to select the receipt
|
|
ctx := r.Context()
|
|
|
|
receipt, err := app.getReceipt(ctx, receiptID)
|
|
if err != nil {
|
|
app.internalServerError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if receipt.OwnerID != getUserFromContext(r).ID {
|
|
app.forbiddenResponse(w, r)
|
|
return
|
|
}
|
|
|
|
ctx = context.WithValue(ctx, receiptCtx, receipt)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
|
|
}
|
|
|
|
func (app *application) checkRolePrecedence(ctx context.Context, user *auth_storage.User, roleName string) (bool, error) {
|
|
role, err := app.auth.Roles.GetByName(ctx, roleName)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return user.Role.Level >= role.Level, nil
|
|
}
|
|
|
|
func (app *application) checkReceiptOwnership(requiredRole string, next http.HandlerFunc) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user := getUserFromContext(r)
|
|
receipt := getReceiptFromContext(r)
|
|
|
|
if receipt.OwnerID == user.ID {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
allowed, err := app.checkRolePrecedence(r.Context(), user, requiredRole)
|
|
if err != nil {
|
|
app.internalServerError(w, r, err)
|
|
return
|
|
}
|
|
|
|
if !allowed {
|
|
app.forbiddenResponse(w, r)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|