Files
tingz/internal/http/middleware.go
2025-10-14 21:20:25 +03:00

139 lines
3.7 KiB
Go

package http
import (
"context"
"crypto/subtle"
"log/slog"
"net/http"
"strings"
"time"
"git.yigid.dev/fyb/tingz/internal/user"
)
// contextKey is the unexported type used for keys this package places in a
// request context. Because the type is private to the package, its keys
// cannot collide with context keys defined by other packages.
type contextKey string

const (
	// usernameKey is the context key under which UserAuthMiddleware stores the
	// authenticated username.
	usernameKey contextKey = "username"
)
// AdminAuthMiddleware returns middleware that only lets a request through when
// it carries an "Authorization: Bearer <token>" header whose token equals
// adminToken. Requests failing the check receive a 401 JSON error response
// (with a WWW-Authenticate: Bearer challenge) and never reach next.
//
// The "Bearer" scheme is matched case-insensitively (RFC 7235 §2.1) and
// whitespace around the token is ignored. An empty token is always rejected,
// so an unset (empty) adminToken denies every request rather than accepting an
// empty bearer credential.
func AdminAuthMiddleware(adminToken string, logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			unauthorized := func(msg string) {
				// A 401 response must carry a WWW-Authenticate challenge (RFC 7235 §3.1).
				w.Header().Set("WWW-Authenticate", "Bearer")
				writeJSON(w, http.StatusUnauthorized, map[string]string{
					"status": "error",
					"error":  msg,
				})
			}

			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				logger.Warn("missing authorization header")
				unauthorized("missing authorization header")
				return
			}

			scheme, token, found := strings.Cut(authHeader, " ")
			token = strings.TrimSpace(token)
			// Rejecting an empty token here keeps ConstantTimeCompare("", "")
			// from authorizing requests when adminToken is unset.
			if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
				logger.Warn("invalid authorization header format")
				unauthorized("invalid authorization header format")
				return
			}

			// Constant-time comparison avoids leaking how many leading bytes match.
			if subtle.ConstantTimeCompare([]byte(token), []byte(adminToken)) != 1 {
				logger.Warn("invalid admin token")
				unauthorized("invalid admin token")
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
// UserAuthMiddleware returns middleware that authenticates a request from its
// "Authorization: Bearer <token>" header by resolving the token to a username
// with userMgr.GetByToken. On success the username is stored in the request
// context under usernameKey (read it back with getUsernameFromContext) and the
// request is passed to next. Otherwise a 401 JSON error response (with a
// WWW-Authenticate: Bearer challenge) is written and next is not called.
//
// The "Bearer" scheme is matched case-insensitively (RFC 7235 §2.1), whitespace
// around the token is ignored, and an empty token is rejected without a lookup.
//
// NOTE(review): every GetByToken error, including any storage failure it may
// report, becomes a 401 "invalid token" — confirm that is intended.
func UserAuthMiddleware(userMgr *user.Manager, logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			unauthorized := func(msg string) {
				// A 401 response must carry a WWW-Authenticate challenge (RFC 7235 §3.1).
				w.Header().Set("WWW-Authenticate", "Bearer")
				writeJSON(w, http.StatusUnauthorized, map[string]string{
					"status": "error",
					"error":  msg,
				})
			}

			authHeader := r.Header.Get("Authorization")
			if authHeader == "" {
				logger.Warn("missing authorization header")
				unauthorized("missing authorization header")
				return
			}

			scheme, token, found := strings.Cut(authHeader, " ")
			token = strings.TrimSpace(token)
			if !found || !strings.EqualFold(scheme, "Bearer") || token == "" {
				logger.Warn("invalid authorization header format")
				unauthorized("invalid authorization header format")
				return
			}

			username, err := userMgr.GetByToken(r.Context(), token)
			if err != nil {
				logger.Warn("invalid user token", slog.String("error", err.Error()))
				unauthorized("invalid token")
				return
			}

			ctx := context.WithValue(r.Context(), usernameKey, username)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// LoggingMiddleware returns middleware that emits one Info-level
// "request completed" entry per request, after next has finished handling it.
// The entry carries the method, URL path, response status, handling duration
// and remote address. The status is captured by wrapping the ResponseWriter
// and starts out as 200, which is what gets logged if the handler never calls
// WriteHeader.
func LoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			rec := &responseWriter{
				ResponseWriter: w,
				statusCode:     http.StatusOK,
			}

			next.ServeHTTP(rec, r)
			elapsed := time.Since(began)

			logger.Info("request completed",
				slog.String("method", r.Method),
				slog.String("path", r.URL.Path),
				slog.Int("status", rec.statusCode),
				slog.Duration("duration", elapsed),
				slog.String("remote_addr", r.RemoteAddr),
			)
		}
		return http.HandlerFunc(serve)
	}
}
// MaxBytesMiddleware returns middleware that caps how many request-body bytes
// downstream handlers may read. It replaces r.Body with an
// http.MaxBytesReader limited to maxBytes before calling next. Reading past the
// limit makes the body return an error to the handler.
func MaxBytesMiddleware(maxBytes int64) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		limit := func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limit)
	}
	return wrap
}
// responseWriter wraps an http.ResponseWriter and records the status code that
// is actually sent to the client, so it can be reported after the handler
// returns.
//
// Callers should initialise statusCode to http.StatusOK. wroteHeader reports
// whether the final (non-1xx) status has been sent. After that point net/http
// ignores further WriteHeader calls, and so does the recorded code.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// WriteHeader records code and forwards it to the wrapped writer.
// Informational 1xx codes other than 101 Switching Protocols may precede the
// final status, so they are recorded but do not lock it in. Once a final
// status has been sent, later calls are superfluous in net/http and no longer
// overwrite the recorded code.
func (rw *responseWriter) WriteHeader(code int) {
	if !rw.wroteHeader {
		rw.statusCode = code
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			rw.wroteHeader = true
		}
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write before any final WriteHeader
// makes net/http send an implicit 200 OK, so the recorded status is pinned to
// 200 in that case.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController can then
// reach optional capabilities (Flush, Hijack, deadlines) that this wrapper
// does not implement itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// getUsernameFromContext reports the username stored in ctx under usernameKey.
// The boolean is true only when a value is present under that key and is a
// string. Otherwise it returns "" and false.
func getUsernameFromContext(ctx context.Context) (string, bool) {
	if name, found := ctx.Value(usernameKey).(string); found {
		return name, true
	}
	return "", false
}