receipt_indexer/backend/cmd/api/middleware.go
2025-06-23 13:09:41 -04:00

232 lines
6.0 KiB
Go

package main
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
// auth_storage "git.ewellenr.ca/receipt_indexer/backend/internal/storage/auth"
l_context "git.ewellenr.ca/receipt_indexer/backend/internal/context"
"github.com/go-chi/chi/v5"
)
// sessionCtx is the context key under which AuthSessionMiddleware stores the
// session token read from the "Session" cookie.
//
// NOTE(review): the key is a plain string. Go's context documentation
// recommends an unexported, dedicated key type to avoid collisions between
// packages — confirm how other files read this key before changing its type.
const sessionCtx string = "session_id"
// AuthSessionMiddleware authenticates a request from its "Session" cookie.
//
// The cookie value is treated as a session token and checked with
// app.store.Sessions.CheckSession. When the session is valid, the user
// returned by app.getUser is stored in the request context under userCtx and
// the token under sessionCtx, and the request is passed on to next.
//
// Every failure (missing cookie, empty token, session check error, invalid
// session, user lookup error) is answered through
// app.unauthorizedErrorResponse and next is not called.
func (app *application) AuthSessionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("Session")
		if err != nil {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Session cookie is missing"))
			return
		}

		token := cookie.Value
		if token == "" {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Empty session token"))
			return
		}

		ctx := r.Context()

		// TODO: should have a different function for this.
		valid, userID, err := app.store.Sessions.CheckSession(ctx, token)
		if err != nil {
			// Fail closed: a session that could not be checked is never
			// treated as authenticated, and the underlying error is passed
			// along instead of being dropped.
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("checking session: %w", err))
			return
		}
		if !valid {
			app.unauthorizedErrorResponse(w, r, fmt.Errorf("Invalid session token"))
			return
		}

		// userID comes from the session check; load the matching user.
		user, err := app.getUser(ctx, userID)
		if err != nil {
			app.unauthorizedErrorResponse(w, r, err)
			return
		}

		// Expose the user and the session token to downstream handlers.
		ctx = context.WithValue(ctx, userCtx, user)
		ctx = context.WithValue(ctx, sessionCtx, token)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// CSRFCheckMiddleware is an unimplemented placeholder for a CSRF check.
//
// NOTE(review): the handler body is empty — it performs no check, writes no
// response and never calls next, so a route wrapped with this middleware
// never reaches its handler. The CSRF scheme (and whether an unimplemented
// check should reject or pass requests) needs to be decided before this is
// mounted on any route.
func (app *application) CSRFCheckMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
})
}
// Paginate is a placeholder for pagination middleware.
//
// TODO: add the page and size to the context.
//
// Until that is implemented the request is forwarded to next unchanged, so a
// route wrapped with Paginate still reaches its handler instead of ending in
// an empty response.
func (app *application) Paginate(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		next.ServeHTTP(w, r)
	})
}
// checkRole wraps next so that it only runs when the user returned by
// getUserFromContext passes app.checkRolePrecedence for roleName.
//
// An error from the precedence check is reported with
// app.internalServerError; a user who is not allowed gets
// app.forbiddenResponse. In both cases next is not called.
//
// TODO: also include things like getting the receipt.
func (app *application) checkRole(next http.Handler, roleName string) http.Handler {
	guard := func(w http.ResponseWriter, r *http.Request) {
		current := getUserFromContext(r)

		ok, checkErr := app.checkRolePrecedence(r.Context(), current, roleName)
		switch {
		case checkErr != nil:
			app.internalServerError(w, r, checkErr)
		case !ok:
			app.forbiddenResponse(w, r)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(guard)
}
// CheckRoleMiddleware is a factory for role-checking middleware: it returns a
// middleware function that wraps its handler with app.checkRole for the given
// roleName.
func (app *application) CheckRoleMiddleware(roleName string) func(next http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		return app.checkRole(next, roleName)
	}
	return wrap
}
// CheckUserMatchingMiddleware lets a request through only when the "userID"
// URL parameter equals the ID of the user returned by getUserFromContext.
//
// A parameter that does not parse as a base-10 64-bit integer is answered
// with app.badRequestResponse; an ID mismatch is answered with
// app.forbiddenResponse.
func (app *application) CheckUserMatchingMiddleware(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		rawID := chi.URLParam(r, "userID")
		requestedID, parseErr := strconv.ParseInt(rawID, 10, 64)
		if parseErr != nil {
			app.badRequestResponse(w, r, fmt.Errorf("Invalid url user ID - Not an integer"))
			return
		}

		current := getUserFromContext(r)
		if requestedID == current.ID {
			next.ServeHTTP(w, r)
			return
		}
		app.forbiddenResponse(w, r)
	}
	return http.HandlerFunc(handler)
}
// RateLimiterMiddleware applies app.rateLimiter to each request, keyed by the
// host part of r.RemoteAddr, when app.config.rateLimiter.Enabled is set.
//
// With the limiter disabled the request goes straight to next. If RemoteAddr
// cannot be split into host and port the request is answered with
// app.internalServerError; if the limiter refuses the request it is answered
// with app.rateLimitExceededResponse, passing the limiter's retry-after value
// as a string.
func (app *application) RateLimiterMiddleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		if !app.config.rateLimiter.Enabled {
			next.ServeHTTP(w, r)
			return
		}

		host, _, splitErr := net.SplitHostPort(r.RemoteAddr)
		if splitErr != nil {
			app.internalServerError(w, r, splitErr)
			return
		}

		permitted, wait := app.rateLimiter.Allow(host)
		if !permitted {
			app.rateLimitExceededResponse(w, r, wait.String())
			return
		}

		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(limited)
}
// receiptContextMiddleware loads the receipt named by the "receiptID" URL
// parameter via app.getReceipt and stores it in the request context under
// receiptCtx before calling next.
//
// A non-integer parameter is answered with app.badRequestResponse, a lookup
// error with app.internalServerError, and a receipt whose OwnerID differs
// from the ID of the user returned by getUserFromContext with
// app.forbiddenResponse.
func (app *application) receiptContextMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id, parseErr := strconv.ParseInt(chi.URLParam(r, "receiptID"), 10, 64)
		if parseErr != nil {
			app.badRequestResponse(w, r, fmt.Errorf("Invalid url receipt ID - Not an integer"))
			return
		}

		reqCtx := r.Context()
		found, lookupErr := app.getReceipt(reqCtx, id)
		if lookupErr != nil {
			app.internalServerError(w, r, lookupErr)
			return
		}

		// Only the owner of the receipt may continue.
		isOwner := found.OwnerID == getUserFromContext(r).ID
		if !isOwner {
			app.forbiddenResponse(w, r)
			return
		}

		withReceipt := context.WithValue(reqCtx, receiptCtx, found)
		next.ServeHTTP(w, r.WithContext(withReceipt))
	})
}
// checkRolePrecedence reports whether user's role level is greater than or
// equal to the level of the role named roleName, which is looked up through
// app.store.Roles.GetByName.
//
// It returns false together with an error when user is nil (no user was
// available to check) or when the role lookup fails.
//
// NOTE(review): the signature refers to the auth_storage package, whose
// import line appears commented out in this file — confirm it is imported.
func (app *application) checkRolePrecedence(ctx context.Context, user *auth_storage.User, roleName string) (bool, error) {
	// Guard against a missing user instead of panicking on user.Role below.
	if user == nil {
		return false, fmt.Errorf("checking role %q: no user available", roleName)
	}

	role, err := app.store.Roles.GetByName(ctx, roleName)
	if err != nil {
		return false, err
	}

	return user.Role.Level >= role.Level, nil
}
// checkReceiptOwnership wraps next so that it runs when the user from
// getUserFromContext owns the receipt from getReceiptFromContext, or,
// failing that, when the user passes app.checkRolePrecedence for
// requiredRole.
//
// An error from the precedence check is answered with
// app.internalServerError; a non-owner without the required role gets
// app.forbiddenResponse.
//
// NOTE(review): receiptContextMiddleware already rejects non-owners; if this
// wrapper runs after it, the role fallback here is never reached — confirm
// the intended ordering.
func (app *application) checkReceiptOwnership(requiredRole string, next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		current := getUserFromContext(r)
		rcpt := getReceiptFromContext(r)

		// Non-owners must hold at least the required role.
		if rcpt.OwnerID != current.ID {
			ok, checkErr := app.checkRolePrecedence(r.Context(), current, requiredRole)
			if checkErr != nil {
				app.internalServerError(w, r, checkErr)
				return
			}
			if !ok {
				app.forbiddenResponse(w, r)
				return
			}
		}

		next.ServeHTTP(w, r)
	}
}
// addGroupToContextMiddleware loads the group named by the "groupID" URL
// parameter via app.getGroup and stores it in the request context under
// groupCtx before calling next.
//
// A non-integer parameter is answered with app.badRequestResponse and a
// lookup error with app.unauthorizedErrorResponse.
//
// NOTE(review): a failed group lookup is reported as unauthorized, whereas
// receiptContextMiddleware reports a failed receipt lookup as an internal
// server error — confirm which response is intended here.
func (app *application) addGroupToContextMiddleware(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		rawID := chi.URLParam(r, "groupID")
		groupID, parseErr := strconv.ParseInt(rawID, 10, 64)
		if parseErr != nil {
			app.badRequestResponse(w, r, fmt.Errorf("Invalid url group ID - Not an integer"))
			return
		}

		reqCtx := r.Context()
		found, lookupErr := app.getGroup(reqCtx, groupID)
		if lookupErr != nil {
			app.unauthorizedErrorResponse(w, r, lookupErr)
			return
		}

		withGroup := context.WithValue(reqCtx, groupCtx, found)
		next.ServeHTTP(w, r.WithContext(withGroup))
	}
	return http.HandlerFunc(handler)
}
// addQueryParamsToContextMiddleware stores the request's parsed URL query
// values (r.URL.Query()) in the request context under
// l_context.QueryParamsCtx and then passes the request on to next.
func (app *application) addQueryParamsToContextMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		params := r.URL.Query()
		enriched := context.WithValue(r.Context(), l_context.QueryParamsCtx, params)
		next.ServeHTTP(w, r.WithContext(enriched))
	})
}