package auth import ( "os" "strconv" "github.com/gofiber/fiber/v2" jwtware "github.com/gofiber/jwt/v2" "github.com/golang-jwt/jwt" ) // JWT error message. func jwtError(c *fiber.Ctx, err error) error { if err.Error() == "Missing or malformed JWT" { return c.Status(fiber.StatusBadRequest).JSON(&fiber.Map{ "status": "error", "message": "Missing or malformed JWT!", }) } return c.Status(fiber.StatusUnauthorized).JSON(&fiber.Map{ "status": "error", "message": "Invalid or expired JWT!", }) } // Guards a specific endpoint in the API. func JWTMiddleware() fiber.Handler { return jwtware.New(jwtware.Config{ ErrorHandler: jwtError, SigningKey: []byte(os.Getenv("JWT_SECRET")), SigningMethod: "HS256", TokenLookup: "cookie:jwt", }) } // Gets user data (their ID) from the JWT middleware. Should be executed after calling 'JWTMiddleware()'. func GetDataFromJWT(c *fiber.Ctx) error { // Get userID from the previous route. jwtData := c.Locals("user").(*jwt.Token) claims := jwtData.Claims.(jwt.MapClaims) parsedUserID := claims["uid"].(string) userID, err := strconv.Atoi(parsedUserID) if err != nil { return c.Status(fiber.StatusInternalServerError).JSON(&fiber.Map{ "status": "fail", "message": err.Error(), }) } // Go to next. c.Locals("currentUser", userID) return c.Next() }