2025-08-21 23:47:15 -04:00

225 lines
5.3 KiB
Go

package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)
// AuthMiddleware carries the configuration and the cached JSON Web Key Set
// used by Authenticator to verify bearer tokens on incoming requests.
type AuthMiddleware struct {
// BaseURL is the base address NewAuthMiddleware uses to build the JWKS URL.
BaseURL string
// WebKeySetsUrl is the JWKS endpoint, built by NewAuthMiddleware as
// "<BaseURL>/realms/<Realm>/protocol/openid-connect/certs".
WebKeySetsUrl string
// Realm is the realm name embedded in both WebKeySetsUrl and Issuer.
Realm string
// KeySet is the key set handed to jwt.Parse when verifying tokens.
// Authenticator rejects every request while it is nil.
KeySet jwk.Set
// Issuer is the value a token's issuer must equal exactly; NewAuthMiddleware
// sets it to "<hostName>/realms/<Realm>".
Issuer string
}
// GetUserDocIDFromContext returns the string stored in ctx under the
// "userDocId" key, which Authenticator populates from the token's
// "identifier" claim.
//
// An error is returned when the key is absent or its value is not a string.
func GetUserDocIDFromContext(ctx context.Context) (string, error) {
	if docID, found := ctx.Value("userDocId").(string); found {
		return docID, nil
	}
	return "", fmt.Errorf("user identifier not found in request context")
}
// GetUserUUIDFromContext returns the string stored in ctx under the
// "userUUID" key, which Authenticator populates from the token's subject.
//
// An error is returned when the key is absent or its value is not a string.
func GetUserUUIDFromContext(ctx context.Context) (string, error) {
	if uuid, found := ctx.Value("userUUID").(string); found {
		return uuid, nil
	}
	return "", fmt.Errorf("user UUID not found in request context")
}
// NewAuthMiddleware builds an AuthMiddleware for the given realm and performs
// an initial JWKS fetch.
//
// The JWKS endpoint is "<baseURL>/realms/<realm>/protocol/openid-connect/certs"
// and the expected token issuer is "<hostName>/realms/<realm>".
//
// A failed fetch is not fatal: the error is printed and the middleware is
// still returned, carrying whatever key set value jwk.Fetch handed back.
func NewAuthMiddleware(baseURL string, realm string, hostName string) *AuthMiddleware {
	certsURL := fmt.Sprintf("%s/realms/%s/protocol/openid-connect/certs", baseURL, realm)

	am := &AuthMiddleware{
		BaseURL:       baseURL,
		WebKeySetsUrl: certsURL,
		Realm:         realm,
		Issuer:        fmt.Sprintf("%s/realms/%s", hostName, realm),
	}

	keys, fetchErr := jwk.Fetch(context.Background(), certsURL)
	if fetchErr != nil {
		fmt.Println(fetchErr)
		fmt.Println("All authenticated requests will be rejected until the JWKS is available")
	}
	am.KeySet = keys

	return am
}
// RefreshKeySet starts a background goroutine that re-fetches the JWKS from
// kam.WebKeySetsUrl once per interval and, on success, replaces kam.KeySet.
// A failed fetch is logged and the previously stored key set is kept.
//
// A non-positive interval is reported and no goroutine is started, because
// time.NewTicker panics for such values and that panic would occur inside the
// background goroutine.
//
// The goroutine has no stop mechanism; once started it keeps ticking.
func (kam *AuthMiddleware) RefreshKeySet(interval time.Duration) {
	if interval <= 0 {
		fmt.Printf("Invalid JWKS refresh interval %s; periodic refresh is disabled\n", interval)
		return
	}

	// NOTE(review): this mutex is local to a single RefreshKeySet call, so it
	// only guards the assignment below against itself. Authenticator reads
	// kam.KeySet without taking any lock; making the swap safe for concurrent
	// readers needs a lock (or atomic value) stored on AuthMiddleware.
	var mu sync.Mutex

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for range ticker.C {
			keySet, err := jwk.Fetch(context.Background(), kam.WebKeySetsUrl)
			if err != nil {
				// Keep serving with the previous key set until a fetch succeeds.
				fmt.Printf("Error fetching JWKS from %s: %v\n", kam.WebKeySetsUrl, err)
				continue
			}

			mu.Lock()
			kam.KeySet = keySet
			mu.Unlock()

			fmt.Printf("Refreshed JWKS from %s\n", kam.WebKeySetsUrl)
		}
	}()
}
// WriteError sends httpErr to the client as a JSON body with a
// 401 Unauthorized status and a "application/json" Content-Type.
//
// The payload is marshalled before any header is written, so a marshalling
// failure can still be answered with a clean 500 Internal Server Error. Once
// the 401 status has been sent, a failed body write cannot be reported to the
// client and is therefore ignored.
func WriteError(w http.ResponseWriter, httpErr []map[string]string) {
	body, err := json.Marshal(httpErr)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Headers must be set before WriteHeader; later changes are not sent.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)

	// The status line is already committed; a write failure here means the
	// connection is unusable, so there is nothing further to send.
	_, _ = w.Write(body)
}
// Authenticator wraps next with bearer-token authentication.
//
// For every request it checks, in order: that a key set is loaded, that an
// Authorization header with the "Bearer " prefix is present, that the token
// parses and validates against kam.KeySet, that its issuer equals kam.Issuer,
// that it carries a non-empty string "identifier" claim, a "realm_access"
// object with a "roles" array containing only strings, and a non-empty
// subject. The first failing check ends the request through WriteError with a
// single {"error", "error_description"} entry.
//
// On success next is invoked with a request context holding "userDocId" (the
// identifier claim), "roles" (a []string) and "userUUID" (the token subject).
//
// NOTE(review): kam.KeySet is read here without synchronization while
// RefreshKeySet assigns it from another goroutine.
func (kam *AuthMiddleware) Authenticator(next http.Handler) http.Handler {
	const bearerPrefix = "Bearer "

	handle := func(w http.ResponseWriter, r *http.Request) {
		// deny reports one failure to the client via WriteError.
		deny := func(code, description string) {
			WriteError(w, []map[string]string{
				{
					"error":             code,
					"error_description": description,
				},
			})
		}

		if kam.KeySet == nil {
			deny("key_set_not_available", "Key set is not available, please try again later")
			return
		}

		authz := r.Header.Get("Authorization")
		switch {
		case authz == "":
			deny("missing_authorization_header", "Missing Authorization header in request")
			return
		case !strings.HasPrefix(authz, bearerPrefix):
			deny("bad_formatted_authorization_header", "Missing bearer prefix in authorization header")
			return
		}
		rawToken := authz[len(bearerPrefix):]

		token, err := jwt.Parse([]byte(rawToken), jwt.WithKeySet(kam.KeySet), jwt.WithValidate(true))
		if err != nil {
			deny("invalid_token", "Invalid token or signature")
			return
		}

		if token.Issuer() != kam.Issuer {
			deny("invalid_token_issuer", "Invalid issuer in token")
			return
		}

		claims := token.PrivateClaims()

		// A missing or non-string claim leaves docID empty, which the check
		// below rejects, so the assertion result itself is not needed.
		docID, _ := claims["identifier"].(string)
		if docID == "" {
			deny("token_user_id_not_found", "Token does not contain user identifier")
			return
		}

		realmAccess, hasRealmAccess := claims["realm_access"].(map[string]interface{})
		if !hasRealmAccess {
			deny("token_realm_access_not_found", "Token does not contain realm access information")
			return
		}

		rawRoles, hasRoles := realmAccess["roles"].([]interface{})
		if !hasRoles {
			deny("token_roles_not_found", "Token does not contain roles information")
			return
		}

		var roles []string
		for i := range rawRoles {
			name, isString := rawRoles[i].(string)
			if !isString {
				deny("token_invalid_role", "Token contains an invalid role format")
				return
			}
			roles = append(roles, name)
		}

		subject := token.Subject()
		if subject == "" {
			deny("token_user_uuid_not_found", "Token does not contain user UUID")
			return
		}

		// These keys are the ones GetUserDocIDFromContext and
		// GetUserUUIDFromContext look up.
		ctx := r.Context()
		ctx = context.WithValue(ctx, "userDocId", docID)
		ctx = context.WithValue(ctx, "roles", roles)
		ctx = context.WithValue(ctx, "userUUID", subject)

		next.ServeHTTP(w, r.WithContext(ctx))
	}

	return http.HandlerFunc(handle)
}