This commit is contained in:
2026-01-25 21:59:00 +08:00
parent 7fd537bef3
commit 4cad3f0250
118 changed files with 30473 additions and 0 deletions

View File

@@ -0,0 +1,476 @@
package service
import (
"errors"
"fmt"
"strings"
"accounting-app/internal/models"
"accounting-app/internal/repository"
)
// Classification service errors
var (
ErrClassificationRuleNotFound = errors.New("classification rule not found")
ErrInvalidKeyword = errors.New("keyword cannot be empty")
ErrInvalidCategoryID = errors.New("invalid category ID")
ErrInvalidAmountRange = errors.New("min amount cannot be greater than max amount")
ErrRuleAlreadyExists = errors.New("a rule with this keyword and category already exists")
)
// ClassificationRuleInput represents the input data for creating or updating a classification rule
type ClassificationRuleInput struct {
UserID uint `json:"user_id"`
Keyword string `json:"keyword" binding:"required"`
CategoryID uint `json:"category_id" binding:"required"`
MinAmount *float64 `json:"min_amount,omitempty"`
MaxAmount *float64 `json:"max_amount,omitempty"`
}
// ClassificationSuggestion represents a suggested category with confidence score
type ClassificationSuggestion struct {
CategoryID uint `json:"category_id"`
Category *models.Category `json:"category,omitempty"`
Confidence float64 `json:"confidence"` // 0.0 to 1.0
MatchedRule *models.ClassificationRule `json:"matched_rule,omitempty"`
MatchReason string `json:"match_reason"`
}
// ClassificationService handles business logic for smart classification
type ClassificationService struct {
classificationRepo *repository.ClassificationRepository
categoryRepo *repository.CategoryRepository
}
// NewClassificationService creates a new ClassificationService instance
func NewClassificationService(
classificationRepo *repository.ClassificationRepository,
categoryRepo *repository.CategoryRepository,
) *ClassificationService {
return &ClassificationService{
classificationRepo: classificationRepo,
categoryRepo: categoryRepo,
}
}
// CreateRule creates a new classification rule with business logic validation
func (s *ClassificationService) CreateRule(input ClassificationRuleInput) (*models.ClassificationRule, error) {
// Validate keyword
keyword := strings.TrimSpace(input.Keyword)
if keyword == "" {
return nil, ErrInvalidKeyword
}
// Validate category exists
exists, err := s.categoryRepo.ExistsByID(input.UserID, input.CategoryID)
if err != nil {
return nil, fmt.Errorf("failed to validate category: %w", err)
}
if !exists {
return nil, ErrInvalidCategoryID
}
// Validate amount range
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
return nil, ErrInvalidAmountRange
}
// Check if rule already exists
exists, err = s.classificationRepo.ExistsByKeywordAndCategory(input.UserID, keyword, input.CategoryID)
if err != nil {
return nil, fmt.Errorf("failed to check rule existence: %w", err)
}
if exists {
return nil, ErrRuleAlreadyExists
}
// Create the rule
rule := &models.ClassificationRule{
Keyword: keyword,
CategoryID: input.CategoryID,
MinAmount: input.MinAmount,
MaxAmount: input.MaxAmount,
HitCount: 0,
}
if err := s.classificationRepo.Create(rule); err != nil {
return nil, fmt.Errorf("failed to create classification rule: %w", err)
}
// Load the category relationship
rule, err = s.classificationRepo.GetByID(input.UserID, rule.ID)
if err != nil {
return nil, fmt.Errorf("failed to load created rule: %w", err)
}
return rule, nil
}
// GetRule retrieves a classification rule by ID
func (s *ClassificationService) GetRule(userID, id uint) (*models.ClassificationRule, error) {
rule, err := s.classificationRepo.GetByID(userID, id)
if err != nil {
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
return nil, ErrClassificationRuleNotFound
}
return nil, fmt.Errorf("failed to get classification rule: %w", err)
}
return rule, nil
}
// GetAllRules retrieves all classification rules
func (s *ClassificationService) GetAllRules(userID uint) ([]models.ClassificationRule, error) {
rules, err := s.classificationRepo.GetAll(userID)
if err != nil {
return nil, fmt.Errorf("failed to get classification rules: %w", err)
}
return rules, nil
}
// GetRulesByCategory retrieves all classification rules for a specific category
func (s *ClassificationService) GetRulesByCategory(userID, categoryID uint) ([]models.ClassificationRule, error) {
// Validate category exists
exists, err := s.categoryRepo.ExistsByID(userID, categoryID)
if err != nil {
return nil, fmt.Errorf("failed to validate category: %w", err)
}
if !exists {
return nil, ErrInvalidCategoryID
}
rules, err := s.classificationRepo.GetByCategoryID(userID, categoryID)
if err != nil {
return nil, fmt.Errorf("failed to get classification rules: %w", err)
}
return rules, nil
}
// UpdateRule updates an existing classification rule
func (s *ClassificationService) UpdateRule(userID, id uint, input ClassificationRuleInput) (*models.ClassificationRule, error) {
// Get existing rule
rule, err := s.classificationRepo.GetByID(userID, id)
if err != nil {
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
return nil, ErrClassificationRuleNotFound
}
return nil, fmt.Errorf("failed to get classification rule: %w", err)
}
// Validate keyword
keyword := strings.TrimSpace(input.Keyword)
if keyword == "" {
return nil, ErrInvalidKeyword
}
// Validate category exists
exists, err := s.categoryRepo.ExistsByID(userID, input.CategoryID)
if err != nil {
return nil, fmt.Errorf("failed to validate category: %w", err)
}
if !exists {
return nil, ErrInvalidCategoryID
}
// Validate amount range
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
return nil, ErrInvalidAmountRange
}
// Update fields
rule.Keyword = keyword
rule.CategoryID = input.CategoryID
rule.MinAmount = input.MinAmount
rule.MaxAmount = input.MaxAmount
if err := s.classificationRepo.Update(rule); err != nil {
return nil, fmt.Errorf("failed to update classification rule: %w", err)
}
// Reload to get updated category
rule, err = s.classificationRepo.GetByID(userID, id)
if err != nil {
return nil, fmt.Errorf("failed to reload classification rule: %w", err)
}
return rule, nil
}
// DeleteRule deletes a classification rule by ID
func (s *ClassificationService) DeleteRule(userID, id uint) error {
err := s.classificationRepo.Delete(userID, id)
if err != nil {
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
return ErrClassificationRuleNotFound
}
return fmt.Errorf("failed to delete classification rule: %w", err)
}
return nil
}
// SuggestCategory suggests categories based on transaction note and amount
// This is the core smart classification algorithm that runs entirely locally
// Requirement 2.1.1: Recommend most likely category based on historical data
// Requirement 2.1.2: Runs completely locally, no external data transmission
// Requirement 2.1.4: Match based on note keywords and amount range
func (s *ClassificationService) SuggestCategory(userID uint, note string, amount float64) ([]ClassificationSuggestion, error) {
if strings.TrimSpace(note) == "" {
return []ClassificationSuggestion{}, nil
}
// Get all matching rules based on keyword and amount
matchingRules, err := s.classificationRepo.GetMatchingRules(userID, note, amount)
if err != nil {
return nil, fmt.Errorf("failed to get matching rules: %w", err)
}
if len(matchingRules) == 0 {
return []ClassificationSuggestion{}, nil
}
// Calculate confidence scores and build suggestions
suggestions := make([]ClassificationSuggestion, 0, len(matchingRules))
// Find the maximum hit count for normalization
maxHitCount := 0
for _, rule := range matchingRules {
if rule.HitCount > maxHitCount {
maxHitCount = rule.HitCount
}
}
for i := range matchingRules {
rule := &matchingRules[i]
confidence := s.calculateConfidence(note, amount, rule, maxHitCount)
matchReason := s.buildMatchReason(note, amount, rule)
suggestion := ClassificationSuggestion{
CategoryID: rule.CategoryID,
Category: &rule.Category,
Confidence: confidence,
MatchedRule: rule,
MatchReason: matchReason,
}
suggestions = append(suggestions, suggestion)
}
// Sort by confidence (highest first) - already partially sorted by hit_count from DB
s.sortSuggestionsByConfidence(suggestions)
// Deduplicate by category ID, keeping highest confidence
suggestions = s.deduplicateSuggestions(suggestions)
return suggestions, nil
}
// calculateConfidence calculates a confidence score for a rule match
// The score is based on:
// 1. Keyword match quality (exact match vs partial match)
// 2. Amount range match (if specified)
// 3. Historical hit count (popularity)
func (s *ClassificationService) calculateConfidence(note string, amount float64, rule *models.ClassificationRule, maxHitCount int) float64 {
var confidence float64 = 0.0
// Base score for keyword match (0.3 - 0.5)
noteLower := strings.ToLower(note)
keywordLower := strings.ToLower(rule.Keyword)
if noteLower == keywordLower {
// Exact match
confidence += 0.5
} else if strings.Contains(noteLower, keywordLower) {
// Partial match - score based on keyword length relative to note
keywordRatio := float64(len(rule.Keyword)) / float64(len(note))
confidence += 0.3 + (0.2 * keywordRatio)
}
// Amount range match bonus (0.0 - 0.3)
amountBonus := 0.0
hasAmountConstraint := rule.MinAmount != nil || rule.MaxAmount != nil
if hasAmountConstraint {
inRange := true
if rule.MinAmount != nil && amount < *rule.MinAmount {
inRange = false
}
if rule.MaxAmount != nil && amount > *rule.MaxAmount {
inRange = false
}
if inRange {
// Calculate how well the amount fits in the range
if rule.MinAmount != nil && rule.MaxAmount != nil {
rangeSize := *rule.MaxAmount - *rule.MinAmount
if rangeSize > 0 {
// Closer to the middle of the range = higher score
midPoint := (*rule.MinAmount + *rule.MaxAmount) / 2
distanceFromMid := abs(amount - midPoint)
normalizedDistance := distanceFromMid / (rangeSize / 2)
amountBonus = 0.3 * (1 - normalizedDistance)
} else {
amountBonus = 0.3 // Exact amount match
}
} else {
amountBonus = 0.2 // Only one bound specified
}
}
}
confidence += amountBonus
// Historical popularity bonus (0.0 - 0.2)
if maxHitCount > 0 {
popularityRatio := float64(rule.HitCount) / float64(maxHitCount)
confidence += 0.2 * popularityRatio
}
// Cap confidence at 1.0
if confidence > 1.0 {
confidence = 1.0
}
return confidence
}
// buildMatchReason builds a human-readable explanation for why this category was suggested
func (s *ClassificationService) buildMatchReason(note string, amount float64, rule *models.ClassificationRule) string {
reasons := []string{}
// Keyword match reason
noteLower := strings.ToLower(note)
keywordLower := strings.ToLower(rule.Keyword)
if noteLower == keywordLower {
reasons = append(reasons, fmt.Sprintf("备注完全匹配关键词'%s'", rule.Keyword))
} else {
reasons = append(reasons, fmt.Sprintf("备注包含关键词'%s'", rule.Keyword))
}
// Amount range reason
if rule.MinAmount != nil && rule.MaxAmount != nil {
reasons = append(reasons, fmt.Sprintf("金额 %.2f 在范围 %.2f-%.2f 内", amount, *rule.MinAmount, *rule.MaxAmount))
} else if rule.MinAmount != nil {
reasons = append(reasons, fmt.Sprintf("金额 %.2f >= %.2f", amount, *rule.MinAmount))
} else if rule.MaxAmount != nil {
reasons = append(reasons, fmt.Sprintf("金额 %.2f <= %.2f", amount, *rule.MaxAmount))
}
// Hit count reason
if rule.HitCount > 0 {
reasons = append(reasons, fmt.Sprintf("历史匹配 %d 次", rule.HitCount))
}
return strings.Join(reasons, "; ")
}
// sortSuggestionsByConfidence sorts suggestions by confidence in descending order
func (s *ClassificationService) sortSuggestionsByConfidence(suggestions []ClassificationSuggestion) {
// Simple bubble sort for small arrays (typically < 10 items)
n := len(suggestions)
for i := 0; i < n-1; i++ {
for j := 0; j < n-i-1; j++ {
if suggestions[j].Confidence < suggestions[j+1].Confidence {
suggestions[j], suggestions[j+1] = suggestions[j+1], suggestions[j]
}
}
}
}
// deduplicateSuggestions removes duplicate category suggestions, keeping the highest confidence
func (s *ClassificationService) deduplicateSuggestions(suggestions []ClassificationSuggestion) []ClassificationSuggestion {
seen := make(map[uint]bool)
result := make([]ClassificationSuggestion, 0, len(suggestions))
for _, suggestion := range suggestions {
if !seen[suggestion.CategoryID] {
seen[suggestion.CategoryID] = true
result = append(result, suggestion)
}
}
return result
}
// ConfirmSuggestion confirms a classification suggestion, incrementing the hit count
// This is called when the user accepts a suggested category
// Requirement 2.1.3: Update local classification model when user confirms/modifies
func (s *ClassificationService) ConfirmSuggestion(userID, ruleID uint) error {
err := s.classificationRepo.IncrementHitCount(userID, ruleID)
if err != nil {
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
return ErrClassificationRuleNotFound
}
return fmt.Errorf("failed to confirm suggestion: %w", err)
}
return nil
}
// LearnFromTransaction creates or updates a classification rule based on a confirmed transaction
// This allows the system to learn from user behavior
func (s *ClassificationService) LearnFromTransaction(userID uint, note string, amount float64, categoryID uint) error {
if strings.TrimSpace(note) == "" {
return nil // Nothing to learn from empty notes
}
// Extract keywords from the note (simple approach: use the whole note as keyword)
// In a more sophisticated implementation, we could use NLP to extract key phrases
keyword := strings.TrimSpace(note)
// Check if a rule already exists for this keyword and category
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
if err != nil && !errors.Is(err, repository.ErrClassificationRuleNotFound) {
return fmt.Errorf("failed to check existing rule: %w", err)
}
if existingRule != nil {
// Rule exists - increment hit count if same category
if existingRule.CategoryID == categoryID {
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
}
// Different category - could create a new rule or update existing
// For now, we'll create a new rule for the new category
}
// Create a new rule
input := ClassificationRuleInput{
UserID: userID,
Keyword: keyword,
CategoryID: categoryID,
}
// Set amount range based on the transaction amount (±20% range)
minAmount := amount * 0.8
maxAmount := amount * 1.2
input.MinAmount = &minAmount
input.MaxAmount = &maxAmount
_, err = s.CreateRule(input)
if err != nil {
// If rule already exists (same keyword, same category), just increment hit count
if errors.Is(err, ErrRuleAlreadyExists) {
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
if err == nil && existingRule.CategoryID == categoryID {
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
}
}
return fmt.Errorf("failed to learn from transaction: %w", err)
}
return nil
}
// RuleExists checks if a classification rule exists by ID
func (s *ClassificationService) RuleExists(userID, id uint) (bool, error) {
exists, err := s.classificationRepo.ExistsByID(userID, id)
if err != nil {
return false, fmt.Errorf("failed to check rule existence: %w", err)
}
return exists, nil
}
// abs returns the absolute value of a float64
func abs(x float64) float64 {
if x < 0 {
return -x
}
return x
}