This commit is contained in:
2026-01-25 21:59:00 +08:00
parent 7fd537bef3
commit 4cad3f0250
118 changed files with 30473 additions and 0 deletions

View File

@@ -0,0 +1,199 @@
package repository
import (
"errors"
"fmt"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// Classification repository errors
var (
ErrClassificationRuleNotFound = errors.New("classification rule not found")
)
// ClassificationRepository handles database operations for classification rules
type ClassificationRepository struct {
db *gorm.DB
}
// NewClassificationRepository creates a new ClassificationRepository instance
func NewClassificationRepository(db *gorm.DB) *ClassificationRepository {
return &ClassificationRepository{db: db}
}
// Create creates a new classification rule in the database
func (r *ClassificationRepository) Create(rule *models.ClassificationRule) error {
if err := r.db.Create(rule).Error; err != nil {
return fmt.Errorf("failed to create classification rule: %w", err)
}
return nil
}
// GetByID retrieves a classification rule by its ID
func (r *ClassificationRepository) GetByID(userID uint, id uint) (*models.ClassificationRule, error) {
var rule models.ClassificationRule
if err := r.db.Preload("Category").Where("user_id = ?", userID).First(&rule, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrClassificationRuleNotFound
}
return nil, fmt.Errorf("failed to get classification rule: %w", err)
}
return &rule, nil
}
// GetAll retrieves all classification rules from the database
func (r *ClassificationRepository) GetAll(userID uint) ([]models.ClassificationRule, error) {
var rules []models.ClassificationRule
if err := r.db.Preload("Category").Where("user_id = ?", userID).Order("hit_count DESC").Find(&rules).Error; err != nil {
return nil, fmt.Errorf("failed to get classification rules: %w", err)
}
return rules, nil
}
// GetByKeyword retrieves all classification rules that match a keyword (case-insensitive partial match)
func (r *ClassificationRepository) GetByKeyword(userID uint, keyword string) ([]models.ClassificationRule, error) {
var rules []models.ClassificationRule
if err := r.db.Preload("Category").
Where("user_id = ? AND LOWER(keyword) LIKE LOWER(?)", userID, "%"+keyword+"%").
Order("hit_count DESC").
Find(&rules).Error; err != nil {
return nil, fmt.Errorf("failed to get classification rules by keyword: %w", err)
}
return rules, nil
}
// GetByExactKeyword retrieves a classification rule by exact keyword match (case-insensitive)
func (r *ClassificationRepository) GetByExactKeyword(userID uint, keyword string) (*models.ClassificationRule, error) {
var rule models.ClassificationRule
if err := r.db.Preload("Category").
Where("user_id = ? AND LOWER(keyword) = LOWER(?)", userID, keyword).
First(&rule).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrClassificationRuleNotFound
}
return nil, fmt.Errorf("failed to get classification rule by exact keyword: %w", err)
}
return &rule, nil
}
// GetByCategoryID retrieves all classification rules for a specific category
func (r *ClassificationRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.ClassificationRule, error) {
var rules []models.ClassificationRule
if err := r.db.Preload("Category").
Where("user_id = ? AND category_id = ?", userID, categoryID).
Order("hit_count DESC").
Find(&rules).Error; err != nil {
return nil, fmt.Errorf("failed to get classification rules by category: %w", err)
}
return rules, nil
}
// GetMatchingRules retrieves all rules where the keyword is contained in the given note
// and the amount falls within the min/max range (if specified)
func (r *ClassificationRepository) GetMatchingRules(userID uint, note string, amount float64) ([]models.ClassificationRule, error) {
var rules []models.ClassificationRule
// Find rules where the keyword is contained in the note (case-insensitive)
// and the amount is within the specified range (if min/max are set)
query := r.db.Preload("Category").
Where("user_id = ?", userID).
Where("LOWER(?) LIKE '%' || LOWER(keyword) || '%'", note).
Where("(min_amount IS NULL OR ? >= min_amount)", amount).
Where("(max_amount IS NULL OR ? <= max_amount)", amount).
Order("hit_count DESC")
if err := query.Find(&rules).Error; err != nil {
return nil, fmt.Errorf("failed to get matching classification rules: %w", err)
}
return rules, nil
}
// Update updates an existing classification rule in the database
func (r *ClassificationRepository) Update(rule *models.ClassificationRule) error {
// First check if the rule exists
var existing models.ClassificationRule
if err := r.db.First(&existing, rule.ID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrClassificationRuleNotFound
}
return fmt.Errorf("failed to check classification rule existence: %w", err)
}
// Update the rule using Updates to avoid issues with preloaded relationships
// We explicitly update only the fields we want to change
updates := map[string]interface{}{
"keyword": rule.Keyword,
"category_id": rule.CategoryID,
"min_amount": rule.MinAmount,
"max_amount": rule.MaxAmount,
}
if err := r.db.Model(&models.ClassificationRule{}).Where("id = ?", rule.ID).Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update classification rule: %w", err)
}
return nil
}
// IncrementHitCount increments the hit count for a classification rule
func (r *ClassificationRepository) IncrementHitCount(userID uint, id uint) error {
result := r.db.Model(&models.ClassificationRule{}).
Where("user_id = ? AND id = ?", userID, id).
UpdateColumn("hit_count", gorm.Expr("hit_count + 1"))
if result.Error != nil {
return fmt.Errorf("failed to increment hit count: %w", result.Error)
}
if result.RowsAffected == 0 {
return ErrClassificationRuleNotFound
}
return nil
}
// Delete deletes a classification rule by its ID
func (r *ClassificationRepository) Delete(userID uint, id uint) error {
// First check if the rule exists
var rule models.ClassificationRule
if err := r.db.Where("user_id = ?", userID).First(&rule, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrClassificationRuleNotFound
}
return fmt.Errorf("failed to check classification rule existence: %w", err)
}
// Delete the rule
if err := r.db.Delete(&rule).Error; err != nil {
return fmt.Errorf("failed to delete classification rule: %w", err)
}
return nil
}
// ExistsByID checks if a classification rule with the given ID exists
func (r *ClassificationRepository) ExistsByID(userID uint, id uint) (bool, error) {
var count int64
if err := r.db.Model(&models.ClassificationRule{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil {
return false, fmt.Errorf("failed to check classification rule existence: %w", err)
}
return count > 0, nil
}
// ExistsByKeywordAndCategory checks if a rule with the given keyword and category already exists
func (r *ClassificationRepository) ExistsByKeywordAndCategory(userID uint, keyword string, categoryID uint) (bool, error) {
var count int64
if err := r.db.Model(&models.ClassificationRule{}).
Where("user_id = ? AND LOWER(keyword) = LOWER(?) AND category_id = ?", userID, keyword, categoryID).
Count(&count).Error; err != nil {
return false, fmt.Errorf("failed to check classification rule existence: %w", err)
}
return count > 0, nil
}
// DeleteByCategoryID deletes all classification rules for a specific category
func (r *ClassificationRepository) DeleteByCategoryID(userID uint, categoryID uint) error {
if err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID).Delete(&models.ClassificationRule{}).Error; err != nil {
return fmt.Errorf("failed to delete classification rules by category: %w", err)
}
return nil
}