// Files
// Novault-backend/internal/handler/classification_handler.go
// 2026-01-25 21:59:00 +08:00
//
// 331 lines
// 9.2 KiB
// Go

package handler
import (
"errors"
"strconv"
"accounting-app/pkg/api"
"accounting-app/internal/service"
"github.com/gin-gonic/gin"
)
// ClassificationHandler handles HTTP requests for smart classification operations.
//
// Every endpoint method delegates its work to classificationService and reads
// the authenticated user's ID from the gin context key "user_id".
type ClassificationHandler struct {
	// classificationService performs the actual suggestion, learning and
	// rule-management operations.
	classificationService *service.ClassificationService
}
// NewClassificationHandler returns a ClassificationHandler that delegates to
// classificationService. The pointer is stored as given; it is not checked
// for nil.
func NewClassificationHandler(classificationService *service.ClassificationService) *ClassificationHandler {
	h := new(ClassificationHandler)
	h.classificationService = classificationService
	return h
}
// SuggestCategoryInput is the JSON request body for POST /api/v1/classify/suggest.
//
// Note is the transaction note to classify. Its `required` binding tag makes
// request binding fail when it is missing or empty. Amount is optional and
// decodes to zero when omitted.
type SuggestCategoryInput struct {
	Note   string  `json:"note" binding:"required"`
	Amount float64 `json:"amount"`
}
// SuggestCategory handles POST /api/v1/classify/suggest.
// Suggests categories based on transaction note and amount.
// This is the smart classification feature that runs entirely locally.
// Requirement 2.1.1: Recommend most likely category based on historical data
// Requirement 2.1.2: Runs completely locally, no external data transmission
//
// Responses:
//   - api.Unauthorized if the context has no "user_id" of type uint
//   - api.ValidationError if the JSON body cannot be bound (note is required)
//   - api.InternalError if the service fails
//   - api.Success with the service's suggestions otherwise
func (h *ClassificationHandler) SuggestCategory(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	var input SuggestCategoryInput
	if err := c.ShouldBindJSON(&input); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}

	suggestions, err := h.classificationService.SuggestCategory(userID, input.Note, input.Amount)
	if err != nil {
		api.InternalError(c, "Failed to get category suggestions: "+err.Error())
		return
	}
	api.Success(c, suggestions)
}
// ConfirmSuggestionInput is the JSON request body for POST /api/v1/classify/confirm.
//
// RuleID identifies the classification rule whose suggestion the user
// accepted. Its `required` binding tag makes request binding fail when the
// field is missing or zero.
type ConfirmSuggestionInput struct {
	RuleID uint `json:"rule_id" binding:"required"`
}
// ConfirmSuggestion handles POST /api/v1/classify/confirm.
// Confirms a classification suggestion, incrementing the hit count.
// Requirement 2.1.3: Update local classification model when user confirms
//
// Responses:
//   - api.Unauthorized if the context has no "user_id" of type uint
//   - api.ValidationError if the JSON body cannot be bound (rule_id is required)
//   - api.NotFound if the service reports service.ErrClassificationRuleNotFound
//   - api.InternalError for any other service error
//   - api.Success with a confirmation message otherwise
func (h *ClassificationHandler) ConfirmSuggestion(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	var input ConfirmSuggestionInput
	if err := c.ShouldBindJSON(&input); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}

	if err := h.classificationService.ConfirmSuggestion(userID, input.RuleID); err != nil {
		if errors.Is(err, service.ErrClassificationRuleNotFound) {
			api.NotFound(c, "Classification rule not found")
			return
		}
		api.InternalError(c, "Failed to confirm suggestion: "+err.Error())
		return
	}
	api.Success(c, gin.H{"message": "Suggestion confirmed successfully"})
}
// LearnInput is the JSON request body for POST /api/v1/classify/learn.
//
// Note and CategoryID carry `required` binding tags, so request binding fails
// when the note is missing or empty or the category ID is missing or zero.
// Amount is optional and decodes to zero when omitted.
type LearnInput struct {
	Note       string  `json:"note" binding:"required"`
	Amount     float64 `json:"amount"`
	CategoryID uint    `json:"category_id" binding:"required"`
}
// LearnFromTransaction handles POST /api/v1/classify/learn.
// Creates or updates a classification rule based on a confirmed transaction.
// This allows the system to learn from user behavior.
//
// Responses:
//   - api.Unauthorized if the context has no "user_id" of type uint
//   - api.ValidationError if the JSON body cannot be bound
//   - api.BadRequest if the service reports service.ErrInvalidCategoryID
//   - api.InternalError for any other service error
//   - api.Success with a confirmation message otherwise
func (h *ClassificationHandler) LearnFromTransaction(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	var input LearnInput
	if err := c.ShouldBindJSON(&input); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}

	if err := h.classificationService.LearnFromTransaction(userID, input.Note, input.Amount, input.CategoryID); err != nil {
		if errors.Is(err, service.ErrInvalidCategoryID) {
			api.BadRequest(c, "Invalid category ID")
			return
		}
		api.InternalError(c, "Failed to learn from transaction: "+err.Error())
		return
	}
	api.Success(c, gin.H{"message": "Learned from transaction successfully"})
}
// CreateRule handles POST /api/v1/classify/rules.
// Creates a new classification rule owned by the authenticated user.
//
// After binding, input.UserID is always overwritten with the authenticated
// user's ID, so any user ID supplied by the client is ignored.
//
// Error mapping:
//   - service.ErrInvalidKeyword, ErrInvalidCategoryID, ErrInvalidAmountRange -> api.BadRequest
//   - service.ErrRuleAlreadyExists -> api.Conflict
//   - anything else -> api.InternalError
//
// On success the created rule is returned via api.Created.
func (h *ClassificationHandler) CreateRule(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	var input service.ClassificationRuleInput
	if err := c.ShouldBindJSON(&input); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}
	input.UserID = userID

	rule, err := h.classificationService.CreateRule(input)
	if err != nil {
		switch {
		case errors.Is(err, service.ErrInvalidKeyword):
			api.BadRequest(c, "Keyword cannot be empty")
		case errors.Is(err, service.ErrInvalidCategoryID):
			api.BadRequest(c, "Invalid category ID")
		case errors.Is(err, service.ErrInvalidAmountRange):
			api.BadRequest(c, "Min amount cannot be greater than max amount")
		case errors.Is(err, service.ErrRuleAlreadyExists):
			api.Conflict(c, "A rule with this keyword and category already exists")
		default:
			api.InternalError(c, "Failed to create classification rule: "+err.Error())
		}
		return
	}
	api.Created(c, rule)
}
// GetRules handles GET /api/v1/classify/rules.
// Returns the classification rules for the authenticated user's ID. When the
// optional "category_id" query parameter is present, only rules for that
// category are requested from the service.
//
// A category_id that is not an unsigned base-10 integer fitting in 32 bits is
// rejected with api.BadRequest before the service is called. If the filtered
// lookup returns service.ErrInvalidCategoryID, that is also reported as
// api.BadRequest. Any other service error is reported via api.InternalError.
func (h *ClassificationHandler) GetRules(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	categoryIDStr := c.Query("category_id")
	if categoryIDStr == "" {
		// No filter: return every rule for this user.
		rules, err := h.classificationService.GetAllRules(userID)
		if err != nil {
			api.InternalError(c, "Failed to get classification rules: "+err.Error())
			return
		}
		api.Success(c, rules)
		return
	}

	categoryID, err := strconv.ParseUint(categoryIDStr, 10, 32)
	if err != nil {
		api.BadRequest(c, "Invalid category ID")
		return
	}
	rules, err := h.classificationService.GetRulesByCategory(userID, uint(categoryID))
	if err != nil {
		if errors.Is(err, service.ErrInvalidCategoryID) {
			api.BadRequest(c, "Invalid category ID")
			return
		}
		api.InternalError(c, "Failed to get classification rules: "+err.Error())
		return
	}
	api.Success(c, rules)
}
// GetRule handles GET /api/v1/classify/rules/:id.
// Returns a single classification rule by ID for the authenticated user.
//
// Responses:
//   - api.BadRequest if :id is not an unsigned base-10 integer fitting in 32 bits
//   - api.NotFound if the service reports service.ErrClassificationRuleNotFound
//   - api.InternalError for any other service error
func (h *ClassificationHandler) GetRule(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	id, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		api.BadRequest(c, "Invalid rule ID")
		return
	}

	rule, err := h.classificationService.GetRule(userID, uint(id))
	if err != nil {
		if errors.Is(err, service.ErrClassificationRuleNotFound) {
			api.NotFound(c, "Classification rule not found")
			return
		}
		api.InternalError(c, "Failed to get classification rule: "+err.Error())
		return
	}
	api.Success(c, rule)
}
// UpdateRule handles PUT /api/v1/classify/rules/:id.
// Updates an existing classification rule owned by the authenticated user.
//
// As in CreateRule, input.UserID is overwritten with the authenticated
// user's ID after binding, so a user ID supplied by the client never reaches
// the service. The same user ID is also passed to the service explicitly.
//
// Error mapping:
//   - service.ErrClassificationRuleNotFound -> api.NotFound
//   - service.ErrInvalidKeyword, ErrInvalidCategoryID, ErrInvalidAmountRange -> api.BadRequest
//   - service.ErrRuleAlreadyExists -> api.Conflict (same as CreateRule; presumably
//     returned when an edit collides with an existing keyword/category pair — confirm)
//   - anything else -> api.InternalError
func (h *ClassificationHandler) UpdateRule(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	id, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		api.BadRequest(c, "Invalid rule ID")
		return
	}

	var input service.ClassificationRuleInput
	if err := c.ShouldBindJSON(&input); err != nil {
		api.ValidationError(c, "Invalid request body: "+err.Error())
		return
	}
	// NOTE(review): confirm whether service.UpdateRule reads input.UserID.
	// Overwriting it is harmless either way and blocks ownership spoofing.
	input.UserID = userID

	rule, err := h.classificationService.UpdateRule(userID, uint(id), input)
	if err != nil {
		switch {
		case errors.Is(err, service.ErrClassificationRuleNotFound):
			api.NotFound(c, "Classification rule not found")
		case errors.Is(err, service.ErrInvalidKeyword):
			api.BadRequest(c, "Keyword cannot be empty")
		case errors.Is(err, service.ErrInvalidCategoryID):
			api.BadRequest(c, "Invalid category ID")
		case errors.Is(err, service.ErrInvalidAmountRange):
			api.BadRequest(c, "Min amount cannot be greater than max amount")
		case errors.Is(err, service.ErrRuleAlreadyExists):
			api.Conflict(c, "A rule with this keyword and category already exists")
		default:
			api.InternalError(c, "Failed to update classification rule: "+err.Error())
		}
		return
	}
	api.Success(c, rule)
}
// DeleteRule handles DELETE /api/v1/classify/rules/:id.
// Deletes a classification rule by ID for the authenticated user.
//
// Responses:
//   - api.BadRequest if :id is not an unsigned base-10 integer fitting in 32 bits
//   - api.NotFound if the service reports service.ErrClassificationRuleNotFound
//   - api.InternalError for any other service error
//   - api.NoContent on success
func (h *ClassificationHandler) DeleteRule(c *gin.Context) {
	// Comma-ok assertion: a missing key or a value of the wrong type is
	// treated as "not authenticated" instead of panicking.
	rawUserID, exists := c.Get("user_id")
	userID, ok := rawUserID.(uint)
	if !exists || !ok {
		api.Unauthorized(c, "User not authenticated")
		return
	}

	id, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		api.BadRequest(c, "Invalid rule ID")
		return
	}

	if err := h.classificationService.DeleteRule(userID, uint(id)); err != nil {
		if errors.Is(err, service.ErrClassificationRuleNotFound) {
			api.NotFound(c, "Classification rule not found")
			return
		}
		api.InternalError(c, "Failed to delete classification rule: "+err.Error())
		return
	}
	api.NoContent(c)
}
// RegisterRoutes mounts the classification endpoints on rg, in this order:
//
//	POST   /classify/suggest
//	POST   /classify/confirm
//	POST   /classify/learn
//	POST   /classify/rules
//	GET    /classify/rules
//	GET    /classify/rules/:id
//	PUT    /classify/rules/:id
//	DELETE /classify/rules/:id
func (h *ClassificationHandler) RegisterRoutes(rg *gin.RouterGroup) {
	type route struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}

	classify := rg.Group("/classify")
	// Smart classification suggestion endpoints (Requirement 2.1).
	for _, r := range []route{
		{"POST", "/suggest", h.SuggestCategory},
		{"POST", "/confirm", h.ConfirmSuggestion},
		{"POST", "/learn", h.LearnFromTransaction},
	} {
		classify.Handle(r.method, r.path, r.handler)
	}

	// Classification rules management.
	ruleGroup := classify.Group("/rules")
	for _, r := range []route{
		{"POST", "", h.CreateRule},
		{"GET", "", h.GetRules},
		{"GET", "/:id", h.GetRule},
		{"PUT", "/:id", h.UpdateRule},
		{"DELETE", "/:id", h.DeleteRule},
	} {
		ruleGroup.Handle(r.method, r.path, r.handler)
	}
}