477 lines
16 KiB
Go
477 lines
16 KiB
Go
package service
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"accounting-app/internal/models"
|
|
"accounting-app/internal/repository"
|
|
)
|
|
|
|
// Sentinel errors returned by ClassificationService. Compare with errors.Is.
// Repository-level "rule not found" errors are translated into
// ErrClassificationRuleNotFound by the service methods that look rules up.
var (
	// ErrClassificationRuleNotFound is returned when no rule with the requested
	// ID exists for the user.
	ErrClassificationRuleNotFound = errors.New("classification rule not found")

	// ErrInvalidKeyword is returned when a rule keyword is empty after trimming
	// surrounding whitespace.
	ErrInvalidKeyword = errors.New("keyword cannot be empty")

	// ErrInvalidCategoryID is returned when the referenced category does not
	// exist for the user.
	ErrInvalidCategoryID = errors.New("invalid category ID")

	// ErrInvalidAmountRange is returned when both amount bounds are given and
	// the minimum exceeds the maximum.
	ErrInvalidAmountRange = errors.New("min amount cannot be greater than max amount")

	// ErrRuleAlreadyExists is returned by CreateRule when the keyword/category
	// pair is already used by another rule.
	ErrRuleAlreadyExists = errors.New("a rule with this keyword and category already exists")
)
|
|
|
|
// ClassificationRuleInput carries the user-supplied fields for CreateRule and
// UpdateRule. Keyword is trimmed and must be non-empty, and CategoryID must
// refer to an existing category of the user.
type ClassificationRuleInput struct {
	// UserID owns the rule. CreateRule reads it; UpdateRule ignores it and uses
	// its own userID argument.
	// NOTE(review): this is bindable from JSON ("user_id"); confirm handlers
	// overwrite it with the authenticated user.
	UserID uint `json:"user_id"`

	// Keyword is matched against transaction notes.
	Keyword string `json:"keyword" binding:"required"`

	// CategoryID is the category suggested when the rule matches.
	CategoryID uint `json:"category_id" binding:"required"`

	// MinAmount and MaxAmount are optional inclusive amount bounds. When both
	// are present, MinAmount must not exceed MaxAmount.
	MinAmount *float64 `json:"min_amount,omitempty"`
	MaxAmount *float64 `json:"max_amount,omitempty"`
}
|
|
|
|
// ClassificationSuggestion represents a suggested category with confidence score
|
|
type ClassificationSuggestion struct {
|
|
CategoryID uint `json:"category_id"`
|
|
Category *models.Category `json:"category,omitempty"`
|
|
Confidence float64 `json:"confidence"` // 0.0 to 1.0
|
|
MatchedRule *models.ClassificationRule `json:"matched_rule,omitempty"`
|
|
MatchReason string `json:"match_reason"`
|
|
}
|
|
|
|
// ClassificationService handles business logic for smart classification
|
|
type ClassificationService struct {
|
|
classificationRepo *repository.ClassificationRepository
|
|
categoryRepo *repository.CategoryRepository
|
|
}
|
|
|
|
// NewClassificationService creates a new ClassificationService instance
|
|
func NewClassificationService(
|
|
classificationRepo *repository.ClassificationRepository,
|
|
categoryRepo *repository.CategoryRepository,
|
|
) *ClassificationService {
|
|
return &ClassificationService{
|
|
classificationRepo: classificationRepo,
|
|
categoryRepo: categoryRepo,
|
|
}
|
|
}
|
|
|
|
// CreateRule creates a new classification rule with business logic validation
|
|
func (s *ClassificationService) CreateRule(input ClassificationRuleInput) (*models.ClassificationRule, error) {
|
|
// Validate keyword
|
|
keyword := strings.TrimSpace(input.Keyword)
|
|
if keyword == "" {
|
|
return nil, ErrInvalidKeyword
|
|
}
|
|
|
|
// Validate category exists
|
|
exists, err := s.categoryRepo.ExistsByID(input.UserID, input.CategoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to validate category: %w", err)
|
|
}
|
|
if !exists {
|
|
return nil, ErrInvalidCategoryID
|
|
}
|
|
|
|
// Validate amount range
|
|
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
|
|
return nil, ErrInvalidAmountRange
|
|
}
|
|
|
|
// Check if rule already exists
|
|
exists, err = s.classificationRepo.ExistsByKeywordAndCategory(input.UserID, keyword, input.CategoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to check rule existence: %w", err)
|
|
}
|
|
if exists {
|
|
return nil, ErrRuleAlreadyExists
|
|
}
|
|
|
|
// Create the rule
|
|
rule := &models.ClassificationRule{
|
|
Keyword: keyword,
|
|
CategoryID: input.CategoryID,
|
|
MinAmount: input.MinAmount,
|
|
MaxAmount: input.MaxAmount,
|
|
HitCount: 0,
|
|
}
|
|
|
|
if err := s.classificationRepo.Create(rule); err != nil {
|
|
return nil, fmt.Errorf("failed to create classification rule: %w", err)
|
|
}
|
|
|
|
// Load the category relationship
|
|
rule, err = s.classificationRepo.GetByID(input.UserID, rule.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load created rule: %w", err)
|
|
}
|
|
|
|
return rule, nil
|
|
}
|
|
|
|
// GetRule retrieves a classification rule by ID
|
|
func (s *ClassificationService) GetRule(userID, id uint) (*models.ClassificationRule, error) {
|
|
rule, err := s.classificationRepo.GetByID(userID, id)
|
|
if err != nil {
|
|
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
|
return nil, ErrClassificationRuleNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get classification rule: %w", err)
|
|
}
|
|
return rule, nil
|
|
}
|
|
|
|
// GetAllRules retrieves all classification rules
|
|
func (s *ClassificationService) GetAllRules(userID uint) ([]models.ClassificationRule, error) {
|
|
rules, err := s.classificationRepo.GetAll(userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get classification rules: %w", err)
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
// GetRulesByCategory retrieves all classification rules for a specific category
|
|
func (s *ClassificationService) GetRulesByCategory(userID, categoryID uint) ([]models.ClassificationRule, error) {
|
|
// Validate category exists
|
|
exists, err := s.categoryRepo.ExistsByID(userID, categoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to validate category: %w", err)
|
|
}
|
|
if !exists {
|
|
return nil, ErrInvalidCategoryID
|
|
}
|
|
|
|
rules, err := s.classificationRepo.GetByCategoryID(userID, categoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get classification rules: %w", err)
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
// UpdateRule updates an existing classification rule
|
|
func (s *ClassificationService) UpdateRule(userID, id uint, input ClassificationRuleInput) (*models.ClassificationRule, error) {
|
|
// Get existing rule
|
|
rule, err := s.classificationRepo.GetByID(userID, id)
|
|
if err != nil {
|
|
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
|
return nil, ErrClassificationRuleNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get classification rule: %w", err)
|
|
}
|
|
|
|
// Validate keyword
|
|
keyword := strings.TrimSpace(input.Keyword)
|
|
if keyword == "" {
|
|
return nil, ErrInvalidKeyword
|
|
}
|
|
|
|
// Validate category exists
|
|
exists, err := s.categoryRepo.ExistsByID(userID, input.CategoryID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to validate category: %w", err)
|
|
}
|
|
if !exists {
|
|
return nil, ErrInvalidCategoryID
|
|
}
|
|
|
|
// Validate amount range
|
|
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
|
|
return nil, ErrInvalidAmountRange
|
|
}
|
|
|
|
// Update fields
|
|
rule.Keyword = keyword
|
|
rule.CategoryID = input.CategoryID
|
|
rule.MinAmount = input.MinAmount
|
|
rule.MaxAmount = input.MaxAmount
|
|
|
|
if err := s.classificationRepo.Update(rule); err != nil {
|
|
return nil, fmt.Errorf("failed to update classification rule: %w", err)
|
|
}
|
|
|
|
// Reload to get updated category
|
|
rule, err = s.classificationRepo.GetByID(userID, id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to reload classification rule: %w", err)
|
|
}
|
|
|
|
return rule, nil
|
|
}
|
|
|
|
// DeleteRule deletes a classification rule by ID
|
|
func (s *ClassificationService) DeleteRule(userID, id uint) error {
|
|
err := s.classificationRepo.Delete(userID, id)
|
|
if err != nil {
|
|
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
|
return ErrClassificationRuleNotFound
|
|
}
|
|
return fmt.Errorf("failed to delete classification rule: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SuggestCategory suggests categories based on transaction note and amount
|
|
// This is the core smart classification algorithm that runs entirely locally
|
|
// Requirement 2.1.1: Recommend most likely category based on historical data
|
|
// Requirement 2.1.2: Runs completely locally, no external data transmission
|
|
// Requirement 2.1.4: Match based on note keywords and amount range
|
|
func (s *ClassificationService) SuggestCategory(userID uint, note string, amount float64) ([]ClassificationSuggestion, error) {
|
|
if strings.TrimSpace(note) == "" {
|
|
return []ClassificationSuggestion{}, nil
|
|
}
|
|
|
|
// Get all matching rules based on keyword and amount
|
|
matchingRules, err := s.classificationRepo.GetMatchingRules(userID, note, amount)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get matching rules: %w", err)
|
|
}
|
|
|
|
if len(matchingRules) == 0 {
|
|
return []ClassificationSuggestion{}, nil
|
|
}
|
|
|
|
// Calculate confidence scores and build suggestions
|
|
suggestions := make([]ClassificationSuggestion, 0, len(matchingRules))
|
|
|
|
// Find the maximum hit count for normalization
|
|
maxHitCount := 0
|
|
for _, rule := range matchingRules {
|
|
if rule.HitCount > maxHitCount {
|
|
maxHitCount = rule.HitCount
|
|
}
|
|
}
|
|
|
|
for i := range matchingRules {
|
|
rule := &matchingRules[i]
|
|
confidence := s.calculateConfidence(note, amount, rule, maxHitCount)
|
|
|
|
matchReason := s.buildMatchReason(note, amount, rule)
|
|
|
|
suggestion := ClassificationSuggestion{
|
|
CategoryID: rule.CategoryID,
|
|
Category: &rule.Category,
|
|
Confidence: confidence,
|
|
MatchedRule: rule,
|
|
MatchReason: matchReason,
|
|
}
|
|
suggestions = append(suggestions, suggestion)
|
|
}
|
|
|
|
// Sort by confidence (highest first) - already partially sorted by hit_count from DB
|
|
s.sortSuggestionsByConfidence(suggestions)
|
|
|
|
// Deduplicate by category ID, keeping highest confidence
|
|
suggestions = s.deduplicateSuggestions(suggestions)
|
|
|
|
return suggestions, nil
|
|
}
|
|
|
|
// calculateConfidence calculates a confidence score for a rule match
|
|
// The score is based on:
|
|
// 1. Keyword match quality (exact match vs partial match)
|
|
// 2. Amount range match (if specified)
|
|
// 3. Historical hit count (popularity)
|
|
func (s *ClassificationService) calculateConfidence(note string, amount float64, rule *models.ClassificationRule, maxHitCount int) float64 {
|
|
var confidence float64 = 0.0
|
|
|
|
// Base score for keyword match (0.3 - 0.5)
|
|
noteLower := strings.ToLower(note)
|
|
keywordLower := strings.ToLower(rule.Keyword)
|
|
|
|
if noteLower == keywordLower {
|
|
// Exact match
|
|
confidence += 0.5
|
|
} else if strings.Contains(noteLower, keywordLower) {
|
|
// Partial match - score based on keyword length relative to note
|
|
keywordRatio := float64(len(rule.Keyword)) / float64(len(note))
|
|
confidence += 0.3 + (0.2 * keywordRatio)
|
|
}
|
|
|
|
// Amount range match bonus (0.0 - 0.3)
|
|
amountBonus := 0.0
|
|
hasAmountConstraint := rule.MinAmount != nil || rule.MaxAmount != nil
|
|
|
|
if hasAmountConstraint {
|
|
inRange := true
|
|
if rule.MinAmount != nil && amount < *rule.MinAmount {
|
|
inRange = false
|
|
}
|
|
if rule.MaxAmount != nil && amount > *rule.MaxAmount {
|
|
inRange = false
|
|
}
|
|
|
|
if inRange {
|
|
// Calculate how well the amount fits in the range
|
|
if rule.MinAmount != nil && rule.MaxAmount != nil {
|
|
rangeSize := *rule.MaxAmount - *rule.MinAmount
|
|
if rangeSize > 0 {
|
|
// Closer to the middle of the range = higher score
|
|
midPoint := (*rule.MinAmount + *rule.MaxAmount) / 2
|
|
distanceFromMid := abs(amount - midPoint)
|
|
normalizedDistance := distanceFromMid / (rangeSize / 2)
|
|
amountBonus = 0.3 * (1 - normalizedDistance)
|
|
} else {
|
|
amountBonus = 0.3 // Exact amount match
|
|
}
|
|
} else {
|
|
amountBonus = 0.2 // Only one bound specified
|
|
}
|
|
}
|
|
}
|
|
confidence += amountBonus
|
|
|
|
// Historical popularity bonus (0.0 - 0.2)
|
|
if maxHitCount > 0 {
|
|
popularityRatio := float64(rule.HitCount) / float64(maxHitCount)
|
|
confidence += 0.2 * popularityRatio
|
|
}
|
|
|
|
// Cap confidence at 1.0
|
|
if confidence > 1.0 {
|
|
confidence = 1.0
|
|
}
|
|
|
|
return confidence
|
|
}
|
|
|
|
// buildMatchReason builds a human-readable explanation for why this category was suggested
|
|
func (s *ClassificationService) buildMatchReason(note string, amount float64, rule *models.ClassificationRule) string {
|
|
reasons := []string{}
|
|
|
|
// Keyword match reason
|
|
noteLower := strings.ToLower(note)
|
|
keywordLower := strings.ToLower(rule.Keyword)
|
|
|
|
if noteLower == keywordLower {
|
|
reasons = append(reasons, fmt.Sprintf("备注完全匹配关键词'%s'", rule.Keyword))
|
|
} else {
|
|
reasons = append(reasons, fmt.Sprintf("备注包含关键词'%s'", rule.Keyword))
|
|
}
|
|
|
|
// Amount range reason
|
|
if rule.MinAmount != nil && rule.MaxAmount != nil {
|
|
reasons = append(reasons, fmt.Sprintf("金额 %.2f 在范围 %.2f-%.2f 内", amount, *rule.MinAmount, *rule.MaxAmount))
|
|
} else if rule.MinAmount != nil {
|
|
reasons = append(reasons, fmt.Sprintf("金额 %.2f >= %.2f", amount, *rule.MinAmount))
|
|
} else if rule.MaxAmount != nil {
|
|
reasons = append(reasons, fmt.Sprintf("金额 %.2f <= %.2f", amount, *rule.MaxAmount))
|
|
}
|
|
|
|
// Hit count reason
|
|
if rule.HitCount > 0 {
|
|
reasons = append(reasons, fmt.Sprintf("历史匹配 %d 次", rule.HitCount))
|
|
}
|
|
|
|
return strings.Join(reasons, "; ")
|
|
}
|
|
|
|
// sortSuggestionsByConfidence sorts suggestions by confidence in descending order
|
|
func (s *ClassificationService) sortSuggestionsByConfidence(suggestions []ClassificationSuggestion) {
|
|
// Simple bubble sort for small arrays (typically < 10 items)
|
|
n := len(suggestions)
|
|
for i := 0; i < n-1; i++ {
|
|
for j := 0; j < n-i-1; j++ {
|
|
if suggestions[j].Confidence < suggestions[j+1].Confidence {
|
|
suggestions[j], suggestions[j+1] = suggestions[j+1], suggestions[j]
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// deduplicateSuggestions removes duplicate category suggestions, keeping the highest confidence
|
|
func (s *ClassificationService) deduplicateSuggestions(suggestions []ClassificationSuggestion) []ClassificationSuggestion {
|
|
seen := make(map[uint]bool)
|
|
result := make([]ClassificationSuggestion, 0, len(suggestions))
|
|
|
|
for _, suggestion := range suggestions {
|
|
if !seen[suggestion.CategoryID] {
|
|
seen[suggestion.CategoryID] = true
|
|
result = append(result, suggestion)
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// ConfirmSuggestion confirms a classification suggestion, incrementing the hit count
|
|
// This is called when the user accepts a suggested category
|
|
// Requirement 2.1.3: Update local classification model when user confirms/modifies
|
|
func (s *ClassificationService) ConfirmSuggestion(userID, ruleID uint) error {
|
|
err := s.classificationRepo.IncrementHitCount(userID, ruleID)
|
|
if err != nil {
|
|
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
|
return ErrClassificationRuleNotFound
|
|
}
|
|
return fmt.Errorf("failed to confirm suggestion: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LearnFromTransaction creates or updates a classification rule based on a confirmed transaction
|
|
// This allows the system to learn from user behavior
|
|
func (s *ClassificationService) LearnFromTransaction(userID uint, note string, amount float64, categoryID uint) error {
|
|
if strings.TrimSpace(note) == "" {
|
|
return nil // Nothing to learn from empty notes
|
|
}
|
|
|
|
// Extract keywords from the note (simple approach: use the whole note as keyword)
|
|
// In a more sophisticated implementation, we could use NLP to extract key phrases
|
|
keyword := strings.TrimSpace(note)
|
|
|
|
// Check if a rule already exists for this keyword and category
|
|
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
|
|
if err != nil && !errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
|
return fmt.Errorf("failed to check existing rule: %w", err)
|
|
}
|
|
|
|
if existingRule != nil {
|
|
// Rule exists - increment hit count if same category
|
|
if existingRule.CategoryID == categoryID {
|
|
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
|
|
}
|
|
// Different category - could create a new rule or update existing
|
|
// For now, we'll create a new rule for the new category
|
|
}
|
|
|
|
// Create a new rule
|
|
input := ClassificationRuleInput{
|
|
UserID: userID,
|
|
Keyword: keyword,
|
|
CategoryID: categoryID,
|
|
}
|
|
|
|
// Set amount range based on the transaction amount (±20% range)
|
|
minAmount := amount * 0.8
|
|
maxAmount := amount * 1.2
|
|
input.MinAmount = &minAmount
|
|
input.MaxAmount = &maxAmount
|
|
|
|
_, err = s.CreateRule(input)
|
|
if err != nil {
|
|
// If rule already exists (same keyword, same category), just increment hit count
|
|
if errors.Is(err, ErrRuleAlreadyExists) {
|
|
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
|
|
if err == nil && existingRule.CategoryID == categoryID {
|
|
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
|
|
}
|
|
}
|
|
return fmt.Errorf("failed to learn from transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RuleExists checks if a classification rule exists by ID
|
|
func (s *ClassificationService) RuleExists(userID, id uint) (bool, error) {
|
|
exists, err := s.classificationRepo.ExistsByID(userID, id)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check rule existence: %w", err)
|
|
}
|
|
return exists, nil
|
|
}
|
|
|
|
// abs returns the magnitude of x: -x when x is negative, otherwise x itself.
// NaN and -0 are therefore returned unchanged.
func abs(x float64) float64 {
	switch {
	case x < 0:
		return -x
	default:
		return x
	}
}
|