bill-server-go/internal/auth/middleware.go
2023-05-05 20:27:33 +08:00

55 lines
1.3 KiB
Go

package auth
import (
"os"
"strconv"
"github.com/gofiber/fiber/v2"
jwtware "github.com/gofiber/jwt/v2"
"github.com/golang-jwt/jwt"
)
// jwtError is the ErrorHandler installed by JWTMiddleware. It turns a token
// failure reported by jwtware into a JSON body of the form
// {"status": "error", "message": ...}:
//   - 400 Bad Request when the error text is "Missing or malformed JWT";
//   - 401 Unauthorized for any other error.
func jwtError(c *fiber.Ctx, err error) error {
	// Default to the generic unauthorized response; downgrade to 400 only for
	// the missing/malformed case. The match is on the exact error text, so it
	// must stay in sync with the string jwtware produces.
	status, message := fiber.StatusUnauthorized, "Invalid or expired JWT!"
	if err.Error() == "Missing or malformed JWT" {
		status, message = fiber.StatusBadRequest, "Missing or malformed JWT!"
	}
	return c.Status(status).JSON(&fiber.Map{
		"status":  "error",
		"message": message,
	})
}
// JWTMiddleware returns a handler that guards an endpoint. The handler reads
// the token from the "jwt" cookie, verifies it with HS256 using the value of
// the JWT_SECRET environment variable, and sends failures to jwtError. The
// variable is read once, when JWTMiddleware is called.
//
// JWTMiddleware panics if JWT_SECRET is unset or empty. An empty HMAC key is
// effectively public, so anyone could sign tokens this middleware would accept.
// Failing at route setup is safer than serving with forgeable authentication.
func JWTMiddleware() fiber.Handler {
	secret := os.Getenv("JWT_SECRET")
	if secret == "" {
		panic("auth: JWT_SECRET environment variable is not set")
	}
	return jwtware.New(jwtware.Config{
		ErrorHandler:  jwtError,
		SigningKey:    []byte(secret),
		SigningMethod: "HS256",
		TokenLookup:   "cookie:jwt",
	})
}
// GetDataFromJWT extracts the user's ID from the token that JWTMiddleware stored
// under c.Locals("user"). It parses the "uid" claim as a decimal integer, saves
// it under c.Locals("currentUser") as an int, and calls the next handler.
//
// It must be registered after JWTMiddleware(). If the token is absent from the
// context, its claims are not jwt.MapClaims, or "uid" is missing, not a string,
// or not an integer, it responds with HTTP 500 and
// {"status": "fail", "message": ...} instead of panicking.
func GetDataFromJWT(c *fiber.Ctx) error {
	// fail writes the 500 JSON error response used by every failure path below.
	fail := func(message string) error {
		return c.Status(fiber.StatusInternalServerError).JSON(&fiber.Map{
			"status":  "fail",
			"message": message,
		})
	}

	// Use comma-ok assertions throughout. A failed plain assertion panics, and
	// Fiber does not recover handler panics unless its recover middleware is
	// installed.
	token, ok := c.Locals("user").(*jwt.Token)
	if !ok || token == nil {
		return fail("Missing JWT in request context!")
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return fail("Unexpected JWT claims type!")
	}
	rawUserID, ok := claims["uid"].(string)
	if !ok {
		return fail("Missing or invalid 'uid' claim in JWT!")
	}
	userID, err := strconv.Atoi(rawUserID)
	if err != nil {
		return fail(err.Error())
	}

	// Expose the numeric ID to downstream handlers.
	c.Locals("currentUser", userID)
	return c.Next()
}