Files
Novault-backend/internal/handler/auth_handler.go

209 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"accounting-app/internal/config"
	"accounting-app/internal/service"
	"accounting-app/pkg/api"

	"github.com/gin-gonic/gin"
)
// AuthHandler groups the HTTP handlers for authentication: registration,
// login, token refresh, current-user lookup, password change and the GitHub
// OAuth login flow.
type AuthHandler struct {
	// authService backs registration, login, token refresh, user lookup and
	// password changes.
	authService *service.AuthService
	// gitHubOAuthService drives the GitHub OAuth flow. When it is nil the
	// GitHub handlers answer that GitHub OAuth is not configured.
	gitHubOAuthService *service.GitHubOAuthService
	// cfg carries optional settings (currently FrontendURL); it may be nil.
	cfg *config.Config
}
// NewAuthHandler builds an AuthHandler with no configuration attached. It is
// equivalent to NewAuthHandlerWithConfig with a nil config, so the GitHub
// callback uses its built-in fallback frontend URL.
func NewAuthHandler(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService) *AuthHandler {
	return NewAuthHandlerWithConfig(authService, gitHubOAuthService, nil)
}
// NewAuthHandlerWithConfig builds an AuthHandler from its services plus an
// optional config. cfg may be nil; when non-nil its FrontendURL, if set,
// decides where GitHub OAuth results are redirected.
func NewAuthHandlerWithConfig(authService *service.AuthService, gitHubOAuthService *service.GitHubOAuthService, cfg *config.Config) *AuthHandler {
	handler := &AuthHandler{
		authService:        authService,
		gitHubOAuthService: gitHubOAuthService,
		cfg:                cfg,
	}
	return handler
}
// RegisterRequest is the JSON body accepted by the registration endpoint.
// Gin's binding tags enforce presence of every field, a password of at least
// 8 characters and a username of 2 to 50 characters.
type RegisterRequest struct {
	Email    string `json:"email" binding:"required"`
	Password string `json:"password" binding:"required,min=8"`
	Username string `json:"username" binding:"required,min=2,max=50"`
}
// LoginRequest is the JSON body accepted by the login endpoint; both fields
// are required.
type LoginRequest struct {
	Email    string `json:"email" binding:"required"`
	Password string `json:"password" binding:"required"`
}
// RefreshRequest is the JSON body accepted by the token-refresh endpoint.
type RefreshRequest struct {
	RefreshToken string `json:"refresh_token" binding:"required"`
}
// handleAuthError writes the HTTP response for an error returned by the auth
// service.
//
// Known sentinel errors from the service package map to specific 4xx
// responses. Any other error is logged server-side and reported to the
// client only as a generic 500. Echoing err.Error() would leak internal
// details (for example storage or upstream failure messages) to callers.
func handleAuthError(c *gin.Context, err error) {
	switch {
	case errors.Is(err, service.ErrInvalidCredentials):
		api.Unauthorized(c, "Invalid email or password")
	case errors.Is(err, service.ErrInvalidEmail):
		api.BadRequest(c, "Invalid email format")
	case errors.Is(err, service.ErrWeakPassword):
		api.BadRequest(c, "Password must be at least 8 characters")
	case errors.Is(err, service.ErrUserNotActive):
		api.Forbidden(c, "User account is not active")
	case errors.Is(err, service.ErrInvalidToken):
		api.Unauthorized(c, "Invalid token")
	case errors.Is(err, service.ErrTokenExpired):
		api.Unauthorized(c, "Token has expired")
	case errors.Is(err, service.ErrUserExists):
		api.Conflict(c, "User with this email already exists")
	default:
		// Keep the detail in server logs only; the client gets a generic message.
		fmt.Printf("[Auth] unexpected error: %v\n", err)
		api.InternalError(c, "Authentication failed")
	}
}
// Register creates a new account from a RegisterRequest JSON body.
// On success it replies 201 with the created user and the tokens returned by
// the auth service. Malformed bodies produce a validation error; service
// failures are mapped by handleAuthError.
func (h *AuthHandler) Register(c *gin.Context) {
	var body RegisterRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		api.ValidationError(c, "Invalid request body: "+bindErr.Error())
		return
	}

	created, issued, regErr := h.authService.Register(service.RegisterInput{
		Email:    body.Email,
		Password: body.Password,
		Username: body.Username,
	})
	if regErr != nil {
		handleAuthError(c, regErr)
		return
	}

	api.Created(c, gin.H{
		"tokens": issued,
		"user":   created,
	})
}
// Login authenticates an email/password pair taken from a LoginRequest JSON
// body. On success it responds with the user and the tokens returned by the
// auth service. Malformed bodies produce a validation error; service
// failures are mapped by handleAuthError.
func (h *AuthHandler) Login(c *gin.Context) {
	var body LoginRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		api.ValidationError(c, "Invalid request body: "+bindErr.Error())
		return
	}

	creds := service.LoginInput{
		Email:    body.Email,
		Password: body.Password,
	}
	account, issued, loginErr := h.authService.Login(creds)
	if loginErr != nil {
		handleAuthError(c, loginErr)
		return
	}

	api.Success(c, gin.H{
		"tokens": issued,
		"user":   account,
	})
}
// RefreshToken passes the refresh token from a RefreshRequest JSON body to
// the auth service and returns whatever tokens the service issues.
// Malformed bodies produce a validation error; service failures are mapped
// by handleAuthError.
func (h *AuthHandler) RefreshToken(c *gin.Context) {
	var body RefreshRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		api.ValidationError(c, "Invalid request body: "+bindErr.Error())
		return
	}

	renewed, refreshErr := h.authService.RefreshToken(body.RefreshToken)
	if refreshErr != nil {
		handleAuthError(c, refreshErr)
		return
	}

	api.Success(c, renewed)
}
// GetCurrentUser responds with the profile of the authenticated caller.
//
// It reads the caller's ID from the gin context key "user_id", which is
// expected to hold a uint (presumably set by authentication middleware;
// confirm at the route setup). A missing or wrongly typed value is treated
// as unauthenticated rather than causing a panic.
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
	raw, exists := c.Get("user_id")
	if !exists {
		api.Unauthorized(c, "User not authenticated")
		return
	}
	// Comma-ok form: a bare .(uint) assertion panics if the middleware ever
	// stores a different type (e.g. int or string).
	userID, ok := raw.(uint)
	if !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	user, err := h.authService.GetUserByID(userID)
	if err != nil {
		// NOTE(review): every lookup error is reported as 404, including any
		// non-"missing" failure the service may return; confirm that is intended.
		api.NotFound(c, "User not found")
		return
	}
	api.Success(c, user)
}
// UpdatePasswordRequest is the JSON body for a password change. Both fields
// are required, and the new password must be at least 8 characters.
type UpdatePasswordRequest struct {
	OldPassword string `json:"old_password" binding:"required"`
	NewPassword string `json:"new_password" binding:"required,min=8"`
}
// UpdatePassword changes the authenticated caller's password using an
// UpdatePasswordRequest JSON body.
//
// The caller's ID comes from the gin context key "user_id", which is
// expected to hold a uint (presumably set by authentication middleware).
// A missing or wrongly typed value is treated as unauthenticated rather than
// causing a panic. Service failures are mapped by handleAuthError.
func (h *AuthHandler) UpdatePassword(c *gin.Context) {
	raw, exists := c.Get("user_id")
	if !exists {
		api.Unauthorized(c, "User not authenticated")
		return
	}
	// Comma-ok form: a bare .(uint) assertion panics on an unexpected type.
	userID, ok := raw.(uint)
	if !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	var req UpdatePasswordRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}

	if err := h.authService.UpdatePassword(userID, req.OldPassword, req.NewPassword); err != nil {
		handleAuthError(c, err)
		return
	}
	api.Success(c, gin.H{"message": "Password updated successfully"})
}
// RegisterRoutes mounts, under "/auth" on rg, the endpoints whose handlers
// do not read an authenticated user ID from the context: registration,
// login, token refresh and the GitHub OAuth login/callback pair.
func (h *AuthHandler) RegisterRoutes(rg *gin.RouterGroup) {
	public := rg.Group("/auth")
	{
		public.POST("/register", h.Register)
		public.POST("/login", h.Login)
		public.POST("/refresh", h.RefreshToken)

		public.GET("/github", h.GitHubLogin)
		public.GET("/github/callback", h.GitHubCallback)
	}
}
// RegisterProtectedRoutes mounts, under "/auth" on rg, the endpoints whose
// handlers read the caller's "user_id" from the gin context. rg is therefore
// expected to carry middleware that sets that key (confirm at the call site).
func (h *AuthHandler) RegisterProtectedRoutes(rg *gin.RouterGroup) {
	protected := rg.Group("/auth")
	{
		protected.GET("/me", h.GetCurrentUser)
		protected.POST("/password", h.UpdatePassword)
	}
}
// githubOAuthStateCookie names the short-lived cookie that binds the OAuth
// "state" value to the browser that started the GitHub login.
const githubOAuthStateCookie = "github_oauth_state"

// githubOAuthStateTTL is the lifetime of the state cookie, in seconds.
const githubOAuthStateTTL = 10 * 60

// GitHubLogin starts the GitHub OAuth flow by redirecting to GitHub's
// authorization URL.
//
// A fresh, unguessable state value is generated for every login and stored
// in an HttpOnly cookie. GitHubCallback only accepts a callback whose
// "state" query parameter matches that cookie. A fixed or client-chosen
// state (the previous behaviour) provides no CSRF protection, so any
// ?state= query parameter on this request is ignored.
func (h *AuthHandler) GitHubLogin(c *gin.Context) {
	if h.gitHubOAuthService == nil {
		api.BadRequest(c, "GitHub OAuth is not configured")
		return
	}

	state, err := newGitHubOAuthState()
	if err != nil {
		fmt.Printf("[Auth] generating GitHub OAuth state failed: %v\n", err)
		api.InternalError(c, "Failed to start GitHub login")
		return
	}

	// Lax (not Strict): GitHub's redirect back is a cross-site top-level GET,
	// and the cookie must still be sent on it.
	c.SetSameSite(http.SameSiteLaxMode)
	c.SetCookie(githubOAuthStateCookie, state, githubOAuthStateTTL, "/", "", requestIsHTTPS(c), true)
	c.Redirect(http.StatusFound, h.gitHubOAuthService.GetAuthorizationURL(state))
}

// GitHubCallback completes the GitHub OAuth flow.
//
// It first verifies that the "state" query parameter matches the cookie set
// by GitHubLogin; the cookie is cleared either way so it cannot be replayed.
// It then hands the authorization code to the GitHub OAuth service and
// redirects the browser to the frontend. Failures go to
// <frontend>/login?error=<code>, using one of invalid_state, missing_code or
// github_auth_failed. Success goes to <frontend>/auth/github/callback with
// the issued tokens and user ID. The frontend base URL is cfg.FrontendURL
// when set, otherwise http://localhost:2613.
func (h *AuthHandler) GitHubCallback(c *gin.Context) {
	if h.gitHubOAuthService == nil {
		api.BadRequest(c, "GitHub OAuth is not configured")
		return
	}

	frontendURL := "http://localhost:2613"
	if h.cfg != nil && h.cfg.FrontendURL != "" {
		frontendURL = h.cfg.FrontendURL
	}
	redirectWithError := func(code string) {
		c.Redirect(http.StatusFound, frontendURL+"/login?error="+url.QueryEscape(code))
	}

	// Read, then immediately expire, the state cookie: each state is single-use.
	expected, cookieErr := c.Cookie(githubOAuthStateCookie)
	c.SetSameSite(http.SameSiteLaxMode)
	c.SetCookie(githubOAuthStateCookie, "", -1, "/", "", requestIsHTTPS(c), true)
	received := c.Query("state")
	if cookieErr != nil || expected == "" ||
		subtle.ConstantTimeCompare([]byte(expected), []byte(received)) != 1 {
		redirectWithError("invalid_state")
		return
	}

	code := c.Query("code")
	if code == "" {
		redirectWithError("missing_code")
		return
	}

	user, tokens, err := h.gitHubOAuthService.HandleCallback(code)
	if err != nil {
		// Log the detail server-side; do not put internal error text in a URL
		// the browser will store and may forward.
		fmt.Printf("[Auth] GitHub callback failed: %v\n", err)
		redirectWithError("github_auth_failed")
		return
	}

	// Redirect to the frontend callback page carrying the token information.
	// NOTE(review): tokens in a query string can end up in browser history,
	// server logs and Referer headers. A URL fragment (#...) or a short-lived
	// one-time code would be safer, but the frontend currently reads query
	// parameters.
	redirectURL := fmt.Sprintf("%s/auth/github/callback?access_token=%s&refresh_token=%s&user_id=%d",
		frontendURL, url.QueryEscape(tokens.AccessToken), url.QueryEscape(tokens.RefreshToken), user.ID)
	c.Redirect(http.StatusFound, redirectURL)
}

// newGitHubOAuthState returns 32 bytes from crypto/rand, encoded as unpadded
// base64url so the value is safe in both a cookie and a query string.
func newGitHubOAuthState() (string, error) {
	buf := make([]byte, 32)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// requestIsHTTPS reports whether the request arrived over TLS, either
// directly or as reported by a TLS-terminating proxy via X-Forwarded-Proto.
// It is used to decide the cookie's Secure flag.
func requestIsHTTPS(c *gin.Context) bool {
	return c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https"
}