package service

import (
	"errors"
	"fmt"
	"math"
	"sort"
	"strings"
	"unicode/utf8"

	"accounting-app/internal/models"
	"accounting-app/internal/repository"
)

// Classification service errors.
var (
	ErrClassificationRuleNotFound = errors.New("classification rule not found")
	ErrInvalidKeyword             = errors.New("keyword cannot be empty")
	ErrInvalidCategoryID          = errors.New("invalid category ID")
	ErrInvalidAmountRange         = errors.New("min amount cannot be greater than max amount")
	ErrRuleAlreadyExists          = errors.New("a rule with this keyword and category already exists")
)

// ClassificationRuleInput represents the input data for creating or updating
// a classification rule. MinAmount and MaxAmount are optional; a nil pointer
// means "no bound on that side".
type ClassificationRuleInput struct {
	UserID     uint     `json:"user_id"`
	Keyword    string   `json:"keyword" binding:"required"`
	CategoryID uint     `json:"category_id" binding:"required"`
	MinAmount  *float64 `json:"min_amount,omitempty"`
	MaxAmount  *float64 `json:"max_amount,omitempty"`
}

// ClassificationSuggestion represents a suggested category with a confidence
// score and the rule that produced it.
type ClassificationSuggestion struct {
	CategoryID  uint                       `json:"category_id"`
	Category    *models.Category           `json:"category,omitempty"`
	Confidence  float64                    `json:"confidence"` // 0.0 to 1.0
	MatchedRule *models.ClassificationRule `json:"matched_rule,omitempty"`
	MatchReason string                     `json:"match_reason"`
}

// ClassificationService handles business logic for smart classification.
type ClassificationService struct {
	classificationRepo *repository.ClassificationRepository
	categoryRepo       *repository.CategoryRepository
}

// NewClassificationService creates a new ClassificationService backed by the
// given repositories.
func NewClassificationService(
	classificationRepo *repository.ClassificationRepository,
	categoryRepo *repository.CategoryRepository,
) *ClassificationService {
	return &ClassificationService{
		classificationRepo: classificationRepo,
		categoryRepo:       categoryRepo,
	}
}

// CreateRule validates the input and creates a new classification rule.
//
// It returns ErrInvalidKeyword when the trimmed keyword is empty,
// ErrInvalidCategoryID when the category does not exist for the user,
// ErrInvalidAmountRange when both bounds are set and min > max, and
// ErrRuleAlreadyExists when the repository reports an existing rule for the
// same keyword and category. Repository failures are wrapped.
func (s *ClassificationService) CreateRule(input ClassificationRuleInput) (*models.ClassificationRule, error) {
	keyword := strings.TrimSpace(input.Keyword)
	if keyword == "" {
		return nil, ErrInvalidKeyword
	}

	categoryExists, err := s.categoryRepo.ExistsByID(input.UserID, input.CategoryID)
	if err != nil {
		return nil, fmt.Errorf("failed to validate category: %w", err)
	}
	if !categoryExists {
		return nil, ErrInvalidCategoryID
	}

	if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
		return nil, ErrInvalidAmountRange
	}

	duplicate, err := s.classificationRepo.ExistsByKeywordAndCategory(input.UserID, keyword, input.CategoryID)
	if err != nil {
		return nil, fmt.Errorf("failed to check rule existence: %w", err)
	}
	if duplicate {
		return nil, ErrRuleAlreadyExists
	}

	// NOTE(review): input.UserID is never copied onto the rule, although every
	// repository lookup is scoped by user ID. If models.ClassificationRule has
	// a user ID field it probably needs to be set here - confirm against the
	// model and repository.
	created := &models.ClassificationRule{
		Keyword:    keyword,
		CategoryID: input.CategoryID,
		MinAmount:  input.MinAmount,
		MaxAmount:  input.MaxAmount,
		HitCount:   0,
	}
	if err := s.classificationRepo.Create(created); err != nil {
		return nil, fmt.Errorf("failed to create classification rule: %w", err)
	}

	// Re-read through the repository so the caller gets the rule as GetByID
	// returns it (presumably with the category relation loaded).
	loaded, err := s.classificationRepo.GetByID(input.UserID, created.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to load created rule: %w", err)
	}
	return loaded, nil
}

// GetRule retrieves a classification rule by ID. The repository's not-found
// error is translated to ErrClassificationRuleNotFound; other errors are
// wrapped.
func (s *ClassificationService) GetRule(userID, id uint) (*models.ClassificationRule, error) {
	rule, err := s.classificationRepo.GetByID(userID, id)
	if errors.Is(err, repository.ErrClassificationRuleNotFound) {
		return nil, ErrClassificationRuleNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get classification rule: %w", err)
	}
	return rule, nil
}

// GetAllRules retrieves all classification rules for the user.
func (s *ClassificationService) GetAllRules(userID uint) ([]models.ClassificationRule, error) {
	rules, err := s.classificationRepo.GetAll(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get classification rules: %w", err)
	}
	return rules, nil
}

// GetRulesByCategory retrieves all classification rules for a specific
// category. It returns ErrInvalidCategoryID when the category does not exist
// for the user.
func (s *ClassificationService) GetRulesByCategory(userID, categoryID uint) ([]models.ClassificationRule, error) {
	categoryExists, err := s.categoryRepo.ExistsByID(userID, categoryID)
	if err != nil {
		return nil, fmt.Errorf("failed to validate category: %w", err)
	}
	if !categoryExists {
		return nil, ErrInvalidCategoryID
	}

	rules, err := s.classificationRepo.GetByCategoryID(userID, categoryID)
	if err != nil {
		return nil, fmt.Errorf("failed to get classification rules: %w", err)
	}
	return rules, nil
}

// UpdateRule replaces the keyword, category and amount bounds of an existing
// rule and returns the rule as re-read from the repository.
//
// The existing rule is looked up first, so a missing rule yields
// ErrClassificationRuleNotFound even when the input is also invalid. Input
// validation then mirrors CreateRule (ErrInvalidKeyword, ErrInvalidCategoryID,
// ErrInvalidAmountRange).
//
// NOTE(review): unlike CreateRule, no duplicate keyword/category check is
// performed here, so an update can produce two rules with the same keyword
// and category - confirm whether that is intended.
func (s *ClassificationService) UpdateRule(userID, id uint, input ClassificationRuleInput) (*models.ClassificationRule, error) {
	rule, err := s.classificationRepo.GetByID(userID, id)
	if errors.Is(err, repository.ErrClassificationRuleNotFound) {
		return nil, ErrClassificationRuleNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get classification rule: %w", err)
	}

	keyword := strings.TrimSpace(input.Keyword)
	if keyword == "" {
		return nil, ErrInvalidKeyword
	}

	categoryExists, err := s.categoryRepo.ExistsByID(userID, input.CategoryID)
	if err != nil {
		return nil, fmt.Errorf("failed to validate category: %w", err)
	}
	if !categoryExists {
		return nil, ErrInvalidCategoryID
	}

	if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
		return nil, ErrInvalidAmountRange
	}

	rule.Keyword = keyword
	rule.CategoryID = input.CategoryID
	rule.MinAmount = input.MinAmount
	rule.MaxAmount = input.MaxAmount

	if err := s.classificationRepo.Update(rule); err != nil {
		return nil, fmt.Errorf("failed to update classification rule: %w", err)
	}

	// Re-read so the returned value reflects the stored state after the
	// category change.
	rule, err = s.classificationRepo.GetByID(userID, id)
	if err != nil {
		return nil, fmt.Errorf("failed to reload classification rule: %w", err)
	}
	return rule, nil
}

// DeleteRule deletes a classification rule by ID. The repository's not-found
// error is translated to ErrClassificationRuleNotFound.
func (s *ClassificationService) DeleteRule(userID, id uint) error {
	err := s.classificationRepo.Delete(userID, id)
	if errors.Is(err, repository.ErrClassificationRuleNotFound) {
		return ErrClassificationRuleNotFound
	}
	if err != nil {
		return fmt.Errorf("failed to delete classification rule: %w", err)
	}
	return nil
}

// SuggestCategory suggests categories based on transaction note and amount.
// This is the core smart classification algorithm.
//
// Requirement 2.1.1: Recommend most likely category based on historical data
// Requirement 2.1.2: Runs completely locally, no external data transmission
// Requirement 2.1.4: Match based on note keywords and amount range
//
// The note is trimmed before matching so that surrounding whitespace does not
// prevent an exact keyword match (stored keywords are always trimmed). A blank
// note or no matching rules yields an empty, non-nil slice. The result is
// ordered by descending confidence and contains at most one suggestion per
// category.
func (s *ClassificationService) SuggestCategory(userID uint, note string, amount float64) ([]ClassificationSuggestion, error) {
	note = strings.TrimSpace(note)
	if note == "" {
		return []ClassificationSuggestion{}, nil
	}

	matchingRules, err := s.classificationRepo.GetMatchingRules(userID, note, amount)
	if err != nil {
		return nil, fmt.Errorf("failed to get matching rules: %w", err)
	}
	if len(matchingRules) == 0 {
		return []ClassificationSuggestion{}, nil
	}

	// The largest hit count among the matches normalizes the popularity bonus.
	maxHitCount := 0
	for i := range matchingRules {
		if matchingRules[i].HitCount > maxHitCount {
			maxHitCount = matchingRules[i].HitCount
		}
	}

	suggestions := make([]ClassificationSuggestion, 0, len(matchingRules))
	for i := range matchingRules {
		// Take the address of the slice element (not a loop copy) so the
		// suggestion's pointers refer to the stored rule.
		rule := &matchingRules[i]
		suggestions = append(suggestions, ClassificationSuggestion{
			CategoryID:  rule.CategoryID,
			Category:    &rule.Category,
			Confidence:  s.calculateConfidence(note, amount, rule, maxHitCount),
			MatchedRule: rule,
			MatchReason: s.buildMatchReason(note, amount, rule),
		})
	}

	// Sort first so that deduplication keeps the highest-confidence entry for
	// each category.
	s.sortSuggestionsByConfidence(suggestions)
	return s.deduplicateSuggestions(suggestions), nil
}

// calculateConfidence calculates a confidence score in [0, 1] for a rule
// match. The score is the sum of:
//  1. Keyword match quality: 0.5 for a case-insensitive exact match, or
//     0.3-0.5 for a substring match, scaled by how much of the note the
//     keyword covers (measured in characters, not bytes).
//  2. Amount range fit: 0.0-0.3, see classificationAmountBonus.
//  3. Historical popularity: 0.0-0.2, the rule's hit count relative to
//     maxHitCount.
func (s *ClassificationService) calculateConfidence(note string, amount float64, rule *models.ClassificationRule, maxHitCount int) float64 {
	confidence := 0.0

	noteLower := strings.ToLower(note)
	keywordLower := strings.ToLower(rule.Keyword)
	switch {
	case noteLower == keywordLower:
		confidence += 0.5
	case strings.Contains(noteLower, keywordLower):
		// Count runes rather than bytes so multi-byte characters are not
		// over-weighted in mixed-script notes. noteLower is non-empty here
		// because it differs from, and contains, keywordLower.
		keywordRatio := float64(utf8.RuneCountInString(keywordLower)) /
			float64(utf8.RuneCountInString(noteLower))
		confidence += 0.3 + 0.2*keywordRatio
	}

	confidence += classificationAmountBonus(amount, rule.MinAmount, rule.MaxAmount)

	if maxHitCount > 0 {
		confidence += 0.2 * float64(rule.HitCount) / float64(maxHitCount)
	}

	if confidence > 1.0 {
		confidence = 1.0
	}
	return confidence
}

// classificationAmountInRange reports whether amount satisfies the optional
// inclusive bounds. A nil bound is treated as unbounded on that side.
func classificationAmountInRange(amount float64, minAmount, maxAmount *float64) bool {
	if minAmount != nil && amount < *minAmount {
		return false
	}
	if maxAmount != nil && amount > *maxAmount {
		return false
	}
	return true
}

// classificationAmountBonus returns the amount-range part of the confidence
// score: 0 when the rule has no bounds or the amount is out of range, 0.2
// when only one bound is set, and up to 0.3 when both are set (highest at the
// midpoint of the range, falling linearly to 0 at either edge; 0.3 when the
// range is a single value).
func classificationAmountBonus(amount float64, minAmount, maxAmount *float64) float64 {
	if minAmount == nil && maxAmount == nil {
		return 0
	}
	if !classificationAmountInRange(amount, minAmount, maxAmount) {
		return 0
	}
	if minAmount == nil || maxAmount == nil {
		return 0.2
	}

	rangeSize := *maxAmount - *minAmount
	if rangeSize <= 0 {
		return 0.3 // Degenerate range: the amount equals the single allowed value.
	}
	midPoint := (*minAmount + *maxAmount) / 2
	normalizedDistance := abs(amount-midPoint) / (rangeSize / 2)
	return 0.3 * (1 - normalizedDistance)
}

// buildMatchReason builds a human-readable explanation for why this category
// was suggested. The parts (keyword, amount, hit count) are joined with "; ".
// An amount reason is only included when the amount actually satisfies the
// rule's bounds, so the text never claims a range match that did not happen.
func (s *ClassificationService) buildMatchReason(note string, amount float64, rule *models.ClassificationRule) string {
	reasons := make([]string, 0, 3)

	if strings.ToLower(note) == strings.ToLower(rule.Keyword) {
		reasons = append(reasons, fmt.Sprintf("备注完全匹配关键词'%s'", rule.Keyword))
	} else {
		reasons = append(reasons, fmt.Sprintf("备注包含关键词'%s'", rule.Keyword))
	}

	if classificationAmountInRange(amount, rule.MinAmount, rule.MaxAmount) {
		switch {
		case rule.MinAmount != nil && rule.MaxAmount != nil:
			reasons = append(reasons, fmt.Sprintf("金额 %.2f 在范围 %.2f-%.2f 内", amount, *rule.MinAmount, *rule.MaxAmount))
		case rule.MinAmount != nil:
			reasons = append(reasons, fmt.Sprintf("金额 %.2f >= %.2f", amount, *rule.MinAmount))
		case rule.MaxAmount != nil:
			reasons = append(reasons, fmt.Sprintf("金额 %.2f <= %.2f", amount, *rule.MaxAmount))
		}
	}

	if rule.HitCount > 0 {
		reasons = append(reasons, fmt.Sprintf("历史匹配 %d 次", rule.HitCount))
	}

	return strings.Join(reasons, "; ")
}

// sortSuggestionsByConfidence sorts suggestions in place by confidence in
// descending order. The sort is stable, so suggestions with equal confidence
// keep their incoming relative order.
func (s *ClassificationService) sortSuggestionsByConfidence(suggestions []ClassificationSuggestion) {
	sort.SliceStable(suggestions, func(i, j int) bool {
		return suggestions[i].Confidence > suggestions[j].Confidence
	})
}

// deduplicateSuggestions returns the suggestions with only the first
// occurrence of each category ID kept. Callers sort by confidence first so
// that the first occurrence is the highest-confidence one.
func (s *ClassificationService) deduplicateSuggestions(suggestions []ClassificationSuggestion) []ClassificationSuggestion {
	seen := make(map[uint]struct{}, len(suggestions))
	result := make([]ClassificationSuggestion, 0, len(suggestions))
	for _, suggestion := range suggestions {
		if _, dup := seen[suggestion.CategoryID]; dup {
			continue
		}
		seen[suggestion.CategoryID] = struct{}{}
		result = append(result, suggestion)
	}
	return result
}

// ConfirmSuggestion confirms a classification suggestion by incrementing the
// hit count of the rule behind it. This is called when the user accepts a
// suggested category.
//
// Requirement 2.1.3: Update local classification model when user confirms/modifies
func (s *ClassificationService) ConfirmSuggestion(userID, ruleID uint) error {
	err := s.classificationRepo.IncrementHitCount(userID, ruleID)
	if errors.Is(err, repository.ErrClassificationRuleNotFound) {
		return ErrClassificationRuleNotFound
	}
	if err != nil {
		return fmt.Errorf("failed to confirm suggestion: %w", err)
	}
	return nil
}

// LearnFromTransaction creates or reinforces a classification rule based on a
// confirmed transaction, allowing the system to learn from user behavior.
//
// The whole trimmed note is used as the keyword (no key-phrase extraction).
// If a rule for that keyword and category already exists its hit count is
// incremented; otherwise a new rule is created with an amount range of +/-20%
// around the transaction amount. A blank note is a no-op.
func (s *ClassificationService) LearnFromTransaction(userID uint, note string, amount float64, categoryID uint) error {
	keyword := strings.TrimSpace(note)
	if keyword == "" {
		return nil // Nothing to learn from empty notes.
	}

	existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
	if err != nil && !errors.Is(err, repository.ErrClassificationRuleNotFound) {
		return fmt.Errorf("failed to check existing rule: %w", err)
	}
	if existingRule != nil && existingRule.CategoryID == categoryID {
		return s.reinforceLearnedRule(userID, existingRule.ID)
	}
	// Either no rule exists for the keyword, or the one found belongs to a
	// different category; in both cases try to create a rule for this one.

	// Order the bounds explicitly: for a negative amount, amount*0.8 is the
	// larger value and CreateRule would reject the range.
	minAmount, maxAmount := amount*0.8, amount*1.2
	if minAmount > maxAmount {
		minAmount, maxAmount = maxAmount, minAmount
	}

	_, createErr := s.CreateRule(ClassificationRuleInput{
		UserID:     userID,
		Keyword:    keyword,
		CategoryID: categoryID,
		MinAmount:  &minAmount,
		MaxAmount:  &maxAmount,
	})
	if createErr == nil {
		return nil
	}
	if !errors.Is(createErr, ErrRuleAlreadyExists) {
		return fmt.Errorf("failed to learn from transaction: %w", createErr)
	}

	// A rule for this keyword and category already exists, but the keyword
	// lookup above returned a rule from another category. Find the right one
	// through the category's own rules and reinforce it.
	rules, err := s.classificationRepo.GetByCategoryID(userID, categoryID)
	if err != nil {
		return fmt.Errorf("failed to learn from transaction: %w", err)
	}
	for i := range rules {
		if rules[i].Keyword == keyword {
			return s.reinforceLearnedRule(userID, rules[i].ID)
		}
	}
	// Fall back to a case-insensitive comparison in case the repository's
	// existence check is case-insensitive.
	for i := range rules {
		if strings.EqualFold(rules[i].Keyword, keyword) {
			return s.reinforceLearnedRule(userID, rules[i].ID)
		}
	}
	return fmt.Errorf("failed to learn from transaction: %w", createErr)
}

// reinforceLearnedRule increments the hit count of a rule on behalf of
// LearnFromTransaction, wrapping any repository error.
func (s *ClassificationService) reinforceLearnedRule(userID, ruleID uint) error {
	if err := s.classificationRepo.IncrementHitCount(userID, ruleID); err != nil {
		return fmt.Errorf("failed to learn from transaction: %w", err)
	}
	return nil
}

// RuleExists checks if a classification rule exists by ID.
func (s *ClassificationService) RuleExists(userID, id uint) (bool, error) {
	exists, err := s.classificationRepo.ExistsByID(userID, id)
	if err != nil {
		return false, fmt.Errorf("failed to check rule existence: %w", err)
	}
	return exists, nil
}

// abs returns the absolute value of a float64.
func abs(x float64) float64 {
	return math.Abs(x)
}