init
This commit is contained in:
476
internal/service/classification_service.go
Normal file
476
internal/service/classification_service.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"accounting-app/internal/models"
|
||||
"accounting-app/internal/repository"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by the classification service. Compare against
// them with errors.Is.

// ErrClassificationRuleNotFound reports that the requested rule does not exist.
var ErrClassificationRuleNotFound = errors.New("classification rule not found")

// ErrInvalidKeyword reports a keyword that is empty after trimming whitespace.
var ErrInvalidKeyword = errors.New("keyword cannot be empty")

// ErrInvalidCategoryID reports a category ID that failed the existence check.
var ErrInvalidCategoryID = errors.New("invalid category ID")

// ErrInvalidAmountRange reports a minimum amount larger than the maximum amount.
var ErrInvalidAmountRange = errors.New("min amount cannot be greater than max amount")

// ErrRuleAlreadyExists reports a duplicate keyword/category combination.
var ErrRuleAlreadyExists = errors.New("a rule with this keyword and category already exists")
|
||||
|
||||
// ClassificationRuleInput carries the caller-supplied data used to create or
// update a classification rule.
type ClassificationRuleInput struct {
	// UserID identifies the user the rule is validated against on creation.
	UserID uint `json:"user_id"`
	// Keyword is the text to match; it is trimmed and must not be empty.
	Keyword string `json:"keyword" binding:"required"`
	// CategoryID is the category the rule points to; it must exist for the user.
	CategoryID uint `json:"category_id" binding:"required"`
	// MinAmount is the optional lower bound of the amount range.
	MinAmount *float64 `json:"min_amount,omitempty"`
	// MaxAmount is the optional upper bound of the amount range.
	MaxAmount *float64 `json:"max_amount,omitempty"`
}
|
||||
|
||||
// ClassificationSuggestion represents a suggested category with confidence score
|
||||
type ClassificationSuggestion struct {
|
||||
CategoryID uint `json:"category_id"`
|
||||
Category *models.Category `json:"category,omitempty"`
|
||||
Confidence float64 `json:"confidence"` // 0.0 to 1.0
|
||||
MatchedRule *models.ClassificationRule `json:"matched_rule,omitempty"`
|
||||
MatchReason string `json:"match_reason"`
|
||||
}
|
||||
|
||||
// ClassificationService handles business logic for smart classification
|
||||
type ClassificationService struct {
|
||||
classificationRepo *repository.ClassificationRepository
|
||||
categoryRepo *repository.CategoryRepository
|
||||
}
|
||||
|
||||
// NewClassificationService creates a new ClassificationService instance
|
||||
func NewClassificationService(
|
||||
classificationRepo *repository.ClassificationRepository,
|
||||
categoryRepo *repository.CategoryRepository,
|
||||
) *ClassificationService {
|
||||
return &ClassificationService{
|
||||
classificationRepo: classificationRepo,
|
||||
categoryRepo: categoryRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRule creates a new classification rule with business logic validation
|
||||
func (s *ClassificationService) CreateRule(input ClassificationRuleInput) (*models.ClassificationRule, error) {
|
||||
// Validate keyword
|
||||
keyword := strings.TrimSpace(input.Keyword)
|
||||
if keyword == "" {
|
||||
return nil, ErrInvalidKeyword
|
||||
}
|
||||
|
||||
// Validate category exists
|
||||
exists, err := s.categoryRepo.ExistsByID(input.UserID, input.CategoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate category: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, ErrInvalidCategoryID
|
||||
}
|
||||
|
||||
// Validate amount range
|
||||
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
|
||||
return nil, ErrInvalidAmountRange
|
||||
}
|
||||
|
||||
// Check if rule already exists
|
||||
exists, err = s.classificationRepo.ExistsByKeywordAndCategory(input.UserID, keyword, input.CategoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule existence: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrRuleAlreadyExists
|
||||
}
|
||||
|
||||
// Create the rule
|
||||
rule := &models.ClassificationRule{
|
||||
Keyword: keyword,
|
||||
CategoryID: input.CategoryID,
|
||||
MinAmount: input.MinAmount,
|
||||
MaxAmount: input.MaxAmount,
|
||||
HitCount: 0,
|
||||
}
|
||||
|
||||
if err := s.classificationRepo.Create(rule); err != nil {
|
||||
return nil, fmt.Errorf("failed to create classification rule: %w", err)
|
||||
}
|
||||
|
||||
// Load the category relationship
|
||||
rule, err = s.classificationRepo.GetByID(input.UserID, rule.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load created rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// GetRule retrieves a classification rule by ID
|
||||
func (s *ClassificationService) GetRule(userID, id uint) (*models.ClassificationRule, error) {
|
||||
rule, err := s.classificationRepo.GetByID(userID, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
||||
return nil, ErrClassificationRuleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get classification rule: %w", err)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// GetAllRules retrieves all classification rules
|
||||
func (s *ClassificationService) GetAllRules(userID uint) ([]models.ClassificationRule, error) {
|
||||
rules, err := s.classificationRepo.GetAll(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get classification rules: %w", err)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// GetRulesByCategory retrieves all classification rules for a specific category
|
||||
func (s *ClassificationService) GetRulesByCategory(userID, categoryID uint) ([]models.ClassificationRule, error) {
|
||||
// Validate category exists
|
||||
exists, err := s.categoryRepo.ExistsByID(userID, categoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate category: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, ErrInvalidCategoryID
|
||||
}
|
||||
|
||||
rules, err := s.classificationRepo.GetByCategoryID(userID, categoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get classification rules: %w", err)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// UpdateRule updates an existing classification rule
|
||||
func (s *ClassificationService) UpdateRule(userID, id uint, input ClassificationRuleInput) (*models.ClassificationRule, error) {
|
||||
// Get existing rule
|
||||
rule, err := s.classificationRepo.GetByID(userID, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
||||
return nil, ErrClassificationRuleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get classification rule: %w", err)
|
||||
}
|
||||
|
||||
// Validate keyword
|
||||
keyword := strings.TrimSpace(input.Keyword)
|
||||
if keyword == "" {
|
||||
return nil, ErrInvalidKeyword
|
||||
}
|
||||
|
||||
// Validate category exists
|
||||
exists, err := s.categoryRepo.ExistsByID(userID, input.CategoryID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate category: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return nil, ErrInvalidCategoryID
|
||||
}
|
||||
|
||||
// Validate amount range
|
||||
if input.MinAmount != nil && input.MaxAmount != nil && *input.MinAmount > *input.MaxAmount {
|
||||
return nil, ErrInvalidAmountRange
|
||||
}
|
||||
|
||||
// Update fields
|
||||
rule.Keyword = keyword
|
||||
rule.CategoryID = input.CategoryID
|
||||
rule.MinAmount = input.MinAmount
|
||||
rule.MaxAmount = input.MaxAmount
|
||||
|
||||
if err := s.classificationRepo.Update(rule); err != nil {
|
||||
return nil, fmt.Errorf("failed to update classification rule: %w", err)
|
||||
}
|
||||
|
||||
// Reload to get updated category
|
||||
rule, err = s.classificationRepo.GetByID(userID, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to reload classification rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule deletes a classification rule by ID
|
||||
func (s *ClassificationService) DeleteRule(userID, id uint) error {
|
||||
err := s.classificationRepo.Delete(userID, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
||||
return ErrClassificationRuleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to delete classification rule: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SuggestCategory suggests categories based on transaction note and amount
|
||||
// This is the core smart classification algorithm that runs entirely locally
|
||||
// Requirement 2.1.1: Recommend most likely category based on historical data
|
||||
// Requirement 2.1.2: Runs completely locally, no external data transmission
|
||||
// Requirement 2.1.4: Match based on note keywords and amount range
|
||||
func (s *ClassificationService) SuggestCategory(userID uint, note string, amount float64) ([]ClassificationSuggestion, error) {
|
||||
if strings.TrimSpace(note) == "" {
|
||||
return []ClassificationSuggestion{}, nil
|
||||
}
|
||||
|
||||
// Get all matching rules based on keyword and amount
|
||||
matchingRules, err := s.classificationRepo.GetMatchingRules(userID, note, amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get matching rules: %w", err)
|
||||
}
|
||||
|
||||
if len(matchingRules) == 0 {
|
||||
return []ClassificationSuggestion{}, nil
|
||||
}
|
||||
|
||||
// Calculate confidence scores and build suggestions
|
||||
suggestions := make([]ClassificationSuggestion, 0, len(matchingRules))
|
||||
|
||||
// Find the maximum hit count for normalization
|
||||
maxHitCount := 0
|
||||
for _, rule := range matchingRules {
|
||||
if rule.HitCount > maxHitCount {
|
||||
maxHitCount = rule.HitCount
|
||||
}
|
||||
}
|
||||
|
||||
for i := range matchingRules {
|
||||
rule := &matchingRules[i]
|
||||
confidence := s.calculateConfidence(note, amount, rule, maxHitCount)
|
||||
|
||||
matchReason := s.buildMatchReason(note, amount, rule)
|
||||
|
||||
suggestion := ClassificationSuggestion{
|
||||
CategoryID: rule.CategoryID,
|
||||
Category: &rule.Category,
|
||||
Confidence: confidence,
|
||||
MatchedRule: rule,
|
||||
MatchReason: matchReason,
|
||||
}
|
||||
suggestions = append(suggestions, suggestion)
|
||||
}
|
||||
|
||||
// Sort by confidence (highest first) - already partially sorted by hit_count from DB
|
||||
s.sortSuggestionsByConfidence(suggestions)
|
||||
|
||||
// Deduplicate by category ID, keeping highest confidence
|
||||
suggestions = s.deduplicateSuggestions(suggestions)
|
||||
|
||||
return suggestions, nil
|
||||
}
|
||||
|
||||
// calculateConfidence calculates a confidence score for a rule match
|
||||
// The score is based on:
|
||||
// 1. Keyword match quality (exact match vs partial match)
|
||||
// 2. Amount range match (if specified)
|
||||
// 3. Historical hit count (popularity)
|
||||
func (s *ClassificationService) calculateConfidence(note string, amount float64, rule *models.ClassificationRule, maxHitCount int) float64 {
|
||||
var confidence float64 = 0.0
|
||||
|
||||
// Base score for keyword match (0.3 - 0.5)
|
||||
noteLower := strings.ToLower(note)
|
||||
keywordLower := strings.ToLower(rule.Keyword)
|
||||
|
||||
if noteLower == keywordLower {
|
||||
// Exact match
|
||||
confidence += 0.5
|
||||
} else if strings.Contains(noteLower, keywordLower) {
|
||||
// Partial match - score based on keyword length relative to note
|
||||
keywordRatio := float64(len(rule.Keyword)) / float64(len(note))
|
||||
confidence += 0.3 + (0.2 * keywordRatio)
|
||||
}
|
||||
|
||||
// Amount range match bonus (0.0 - 0.3)
|
||||
amountBonus := 0.0
|
||||
hasAmountConstraint := rule.MinAmount != nil || rule.MaxAmount != nil
|
||||
|
||||
if hasAmountConstraint {
|
||||
inRange := true
|
||||
if rule.MinAmount != nil && amount < *rule.MinAmount {
|
||||
inRange = false
|
||||
}
|
||||
if rule.MaxAmount != nil && amount > *rule.MaxAmount {
|
||||
inRange = false
|
||||
}
|
||||
|
||||
if inRange {
|
||||
// Calculate how well the amount fits in the range
|
||||
if rule.MinAmount != nil && rule.MaxAmount != nil {
|
||||
rangeSize := *rule.MaxAmount - *rule.MinAmount
|
||||
if rangeSize > 0 {
|
||||
// Closer to the middle of the range = higher score
|
||||
midPoint := (*rule.MinAmount + *rule.MaxAmount) / 2
|
||||
distanceFromMid := abs(amount - midPoint)
|
||||
normalizedDistance := distanceFromMid / (rangeSize / 2)
|
||||
amountBonus = 0.3 * (1 - normalizedDistance)
|
||||
} else {
|
||||
amountBonus = 0.3 // Exact amount match
|
||||
}
|
||||
} else {
|
||||
amountBonus = 0.2 // Only one bound specified
|
||||
}
|
||||
}
|
||||
}
|
||||
confidence += amountBonus
|
||||
|
||||
// Historical popularity bonus (0.0 - 0.2)
|
||||
if maxHitCount > 0 {
|
||||
popularityRatio := float64(rule.HitCount) / float64(maxHitCount)
|
||||
confidence += 0.2 * popularityRatio
|
||||
}
|
||||
|
||||
// Cap confidence at 1.0
|
||||
if confidence > 1.0 {
|
||||
confidence = 1.0
|
||||
}
|
||||
|
||||
return confidence
|
||||
}
|
||||
|
||||
// buildMatchReason builds a human-readable explanation for why this category was suggested
|
||||
func (s *ClassificationService) buildMatchReason(note string, amount float64, rule *models.ClassificationRule) string {
|
||||
reasons := []string{}
|
||||
|
||||
// Keyword match reason
|
||||
noteLower := strings.ToLower(note)
|
||||
keywordLower := strings.ToLower(rule.Keyword)
|
||||
|
||||
if noteLower == keywordLower {
|
||||
reasons = append(reasons, fmt.Sprintf("备注完全匹配关键词'%s'", rule.Keyword))
|
||||
} else {
|
||||
reasons = append(reasons, fmt.Sprintf("备注包含关键词'%s'", rule.Keyword))
|
||||
}
|
||||
|
||||
// Amount range reason
|
||||
if rule.MinAmount != nil && rule.MaxAmount != nil {
|
||||
reasons = append(reasons, fmt.Sprintf("金额 %.2f 在范围 %.2f-%.2f 内", amount, *rule.MinAmount, *rule.MaxAmount))
|
||||
} else if rule.MinAmount != nil {
|
||||
reasons = append(reasons, fmt.Sprintf("金额 %.2f >= %.2f", amount, *rule.MinAmount))
|
||||
} else if rule.MaxAmount != nil {
|
||||
reasons = append(reasons, fmt.Sprintf("金额 %.2f <= %.2f", amount, *rule.MaxAmount))
|
||||
}
|
||||
|
||||
// Hit count reason
|
||||
if rule.HitCount > 0 {
|
||||
reasons = append(reasons, fmt.Sprintf("历史匹配 %d 次", rule.HitCount))
|
||||
}
|
||||
|
||||
return strings.Join(reasons, "; ")
|
||||
}
|
||||
|
||||
// sortSuggestionsByConfidence sorts suggestions by confidence in descending order
|
||||
func (s *ClassificationService) sortSuggestionsByConfidence(suggestions []ClassificationSuggestion) {
|
||||
// Simple bubble sort for small arrays (typically < 10 items)
|
||||
n := len(suggestions)
|
||||
for i := 0; i < n-1; i++ {
|
||||
for j := 0; j < n-i-1; j++ {
|
||||
if suggestions[j].Confidence < suggestions[j+1].Confidence {
|
||||
suggestions[j], suggestions[j+1] = suggestions[j+1], suggestions[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deduplicateSuggestions removes duplicate category suggestions, keeping the highest confidence
|
||||
func (s *ClassificationService) deduplicateSuggestions(suggestions []ClassificationSuggestion) []ClassificationSuggestion {
|
||||
seen := make(map[uint]bool)
|
||||
result := make([]ClassificationSuggestion, 0, len(suggestions))
|
||||
|
||||
for _, suggestion := range suggestions {
|
||||
if !seen[suggestion.CategoryID] {
|
||||
seen[suggestion.CategoryID] = true
|
||||
result = append(result, suggestion)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ConfirmSuggestion confirms a classification suggestion, incrementing the hit count
|
||||
// This is called when the user accepts a suggested category
|
||||
// Requirement 2.1.3: Update local classification model when user confirms/modifies
|
||||
func (s *ClassificationService) ConfirmSuggestion(userID, ruleID uint) error {
|
||||
err := s.classificationRepo.IncrementHitCount(userID, ruleID)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
||||
return ErrClassificationRuleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to confirm suggestion: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LearnFromTransaction creates or updates a classification rule based on a confirmed transaction
|
||||
// This allows the system to learn from user behavior
|
||||
func (s *ClassificationService) LearnFromTransaction(userID uint, note string, amount float64, categoryID uint) error {
|
||||
if strings.TrimSpace(note) == "" {
|
||||
return nil // Nothing to learn from empty notes
|
||||
}
|
||||
|
||||
// Extract keywords from the note (simple approach: use the whole note as keyword)
|
||||
// In a more sophisticated implementation, we could use NLP to extract key phrases
|
||||
keyword := strings.TrimSpace(note)
|
||||
|
||||
// Check if a rule already exists for this keyword and category
|
||||
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
|
||||
if err != nil && !errors.Is(err, repository.ErrClassificationRuleNotFound) {
|
||||
return fmt.Errorf("failed to check existing rule: %w", err)
|
||||
}
|
||||
|
||||
if existingRule != nil {
|
||||
// Rule exists - increment hit count if same category
|
||||
if existingRule.CategoryID == categoryID {
|
||||
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
|
||||
}
|
||||
// Different category - could create a new rule or update existing
|
||||
// For now, we'll create a new rule for the new category
|
||||
}
|
||||
|
||||
// Create a new rule
|
||||
input := ClassificationRuleInput{
|
||||
UserID: userID,
|
||||
Keyword: keyword,
|
||||
CategoryID: categoryID,
|
||||
}
|
||||
|
||||
// Set amount range based on the transaction amount (±20% range)
|
||||
minAmount := amount * 0.8
|
||||
maxAmount := amount * 1.2
|
||||
input.MinAmount = &minAmount
|
||||
input.MaxAmount = &maxAmount
|
||||
|
||||
_, err = s.CreateRule(input)
|
||||
if err != nil {
|
||||
// If rule already exists (same keyword, same category), just increment hit count
|
||||
if errors.Is(err, ErrRuleAlreadyExists) {
|
||||
existingRule, err := s.classificationRepo.GetByExactKeyword(userID, keyword)
|
||||
if err == nil && existingRule.CategoryID == categoryID {
|
||||
return s.classificationRepo.IncrementHitCount(userID, existingRule.ID)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to learn from transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RuleExists checks if a classification rule exists by ID
|
||||
func (s *ClassificationService) RuleExists(userID, id uint) (bool, error) {
|
||||
exists, err := s.classificationRepo.ExistsByID(userID, id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check rule existence: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// abs returns the magnitude of x: negative values are negated, everything
// else is returned unchanged.
func abs(x float64) float64 {
	switch {
	case x < 0:
		return -x
	default:
		return x
	}
}
|
||||
Reference in New Issue
Block a user