package repository

import (
	"errors"
	"fmt"
	"strings"

	"accounting-app/internal/models"

	"gorm.io/gorm"
)

// Classification repository errors
var (
	// ErrClassificationRuleNotFound is returned when no classification rule
	// matches the requested ID or keyword for the given user.
	ErrClassificationRuleNotFound = errors.New("classification rule not found")
)

// classificationLikeEscaper escapes LIKE metacharacters in a Go string so it
// is matched literally inside a LIKE pattern. The escape character is '!',
// which must agree with the `ESCAPE '!'` clauses used in the queries below;
// '!' is used instead of a backslash so the SQL literal does not depend on a
// dialect's backslash handling. strings.Replacer substitutes in a single pass,
// so the inserted escape characters are not themselves re-escaped.
var classificationLikeEscaper = strings.NewReplacer(
	"!", "!!",
	"%", "!%",
	"_", "!_",
)

// ClassificationRepository handles database operations for classification rules.
type ClassificationRepository struct {
	db *gorm.DB
}

// NewClassificationRepository creates a new ClassificationRepository backed by db.
func NewClassificationRepository(db *gorm.DB) *ClassificationRepository {
	return &ClassificationRepository{db: db}
}

// Create inserts rule into the database as given. This method does not set
// any user scoping itself; the caller is expected to have populated the rule.
func (r *ClassificationRepository) Create(rule *models.ClassificationRule) error {
	if err := r.db.Create(rule).Error; err != nil {
		return fmt.Errorf("failed to create classification rule: %w", err)
	}
	return nil
}

// GetByID retrieves the rule with primary key id that belongs to userID,
// with its Category relation preloaded. It returns
// ErrClassificationRuleNotFound when no such rule exists for that user.
func (r *ClassificationRepository) GetByID(userID uint, id uint) (*models.ClassificationRule, error) {
	var rule models.ClassificationRule
	if err := r.db.Preload("Category").Where("user_id = ?", userID).First(&rule, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrClassificationRuleNotFound
		}
		return nil, fmt.Errorf("failed to get classification rule: %w", err)
	}
	return &rule, nil
}

// GetAll retrieves all of userID's classification rules with Category
// preloaded, ordered by hit count (highest first). Ties are broken by id so
// the order is deterministic.
func (r *ClassificationRepository) GetAll(userID uint) ([]models.ClassificationRule, error) {
	var rules []models.ClassificationRule
	if err := r.db.Preload("Category").
		Where("user_id = ?", userID).
		Order("hit_count DESC, id ASC").
		Find(&rules).Error; err != nil {
		return nil, fmt.Errorf("failed to get classification rules: %w", err)
	}
	return rules, nil
}

// GetByKeyword retrieves userID's rules whose keyword contains keyword as a
// case-insensitive substring, ordered by hit count (ties by id).
//
// keyword is matched literally: '%' and '_' are escaped so they do not act
// as LIKE wildcards. An empty keyword matches every rule of the user.
func (r *ClassificationRepository) GetByKeyword(userID uint, keyword string) ([]models.ClassificationRule, error) {
	var rules []models.ClassificationRule
	pattern := "%" + classificationLikeEscaper.Replace(keyword) + "%"
	if err := r.db.Preload("Category").
		Where("user_id = ? AND LOWER(keyword) LIKE LOWER(?) ESCAPE '!'", userID, pattern).
		Order("hit_count DESC, id ASC").
		Find(&rules).Error; err != nil {
		return nil, fmt.Errorf("failed to get classification rules by keyword: %w", err)
	}
	return rules, nil
}

// GetByExactKeyword retrieves the first of userID's rules whose keyword
// equals keyword case-insensitively. It returns ErrClassificationRuleNotFound
// when there is none.
func (r *ClassificationRepository) GetByExactKeyword(userID uint, keyword string) (*models.ClassificationRule, error) {
	var rule models.ClassificationRule
	if err := r.db.Preload("Category").
		Where("user_id = ? AND LOWER(keyword) = LOWER(?)", userID, keyword).
		First(&rule).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrClassificationRuleNotFound
		}
		return nil, fmt.Errorf("failed to get classification rule by exact keyword: %w", err)
	}
	return &rule, nil
}

// GetByCategoryID retrieves userID's rules that target categoryID, ordered by
// hit count (ties by id).
func (r *ClassificationRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.ClassificationRule, error) {
	var rules []models.ClassificationRule
	if err := r.db.Preload("Category").
		Where("user_id = ? AND category_id = ?", userID, categoryID).
		Order("hit_count DESC, id ASC").
		Find(&rules).Error; err != nil {
		return nil, fmt.Errorf("failed to get classification rules by category: %w", err)
	}
	return rules, nil
}

// GetMatchingRules retrieves userID's rules whose keyword is contained in note
// (case-insensitive, literal match) and whose optional amount bounds admit
// amount. A NULL min_amount or max_amount means that side is unbounded.
// Results are ordered by hit count, with ties broken by id so the preferred
// rule is deterministic.
func (r *ClassificationRepository) GetMatchingRules(userID uint, note string, amount float64) ([]models.ClassificationRule, error) {
	var rules []models.ClassificationRule

	// The stored keyword becomes a "contains" LIKE pattern. Its '!', '%' and
	// '_' characters are escaped in SQL (with '!' as the escape character) so
	// a keyword such as "_" or "50%" matches only its literal text.
	query := r.db.Preload("Category").
		Where("user_id = ?", userID).
		Where("LOWER(?) LIKE ('%' || "+
			"REPLACE(REPLACE(REPLACE(LOWER(keyword), '!', '!!'), '%', '!%'), '_', '!_')"+
			" || '%') ESCAPE '!'", note).
		Where("(min_amount IS NULL OR ? >= min_amount)", amount).
		Where("(max_amount IS NULL OR ? <= max_amount)", amount).
		Order("hit_count DESC, id ASC")

	if err := query.Find(&rules).Error; err != nil {
		return nil, fmt.Errorf("failed to get matching classification rules: %w", err)
	}
	return rules, nil
}

// Update overwrites the keyword, category_id, min_amount and max_amount
// columns of the rule identified by rule.ID. It returns
// ErrClassificationRuleNotFound if no row has that ID.
//
// NOTE(review): unlike the other methods, this does not filter by user_id.
// Callers must verify ownership (e.g. via GetByID) before calling; consider
// scoping both queries by the rule's owner.
func (r *ClassificationRepository) Update(rule *models.ClassificationRule) error {
	// Existence is checked up front rather than via RowsAffected. Presumably
	// this is because some drivers report 0 affected rows for a no-op update
	// (TODO confirm).
	var existing models.ClassificationRule
	if err := r.db.First(&existing, rule.ID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrClassificationRuleNotFound
		}
		return fmt.Errorf("failed to check classification rule existence: %w", err)
	}

	// A map (rather than the struct) is passed to Updates so that only these
	// columns are written, and nil/zero values are written as-is instead of
	// being skipped.
	updates := map[string]interface{}{
		"keyword":     rule.Keyword,
		"category_id": rule.CategoryID,
		"min_amount":  rule.MinAmount,
		"max_amount":  rule.MaxAmount,
	}

	if err := r.db.Model(&models.ClassificationRule{}).Where("id = ?", rule.ID).Updates(updates).Error; err != nil {
		return fmt.Errorf("failed to update classification rule: %w", err)
	}
	return nil
}

// IncrementHitCount atomically adds 1 to the hit_count of userID's rule id
// in a single UPDATE statement. It returns ErrClassificationRuleNotFound
// when no row was affected.
func (r *ClassificationRepository) IncrementHitCount(userID uint, id uint) error {
	result := r.db.Model(&models.ClassificationRule{}).
		Where("user_id = ? AND id = ?", userID, id).
		UpdateColumn("hit_count", gorm.Expr("hit_count + 1"))
	if result.Error != nil {
		return fmt.Errorf("failed to increment hit count: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return ErrClassificationRuleNotFound
	}
	return nil
}

// Delete removes userID's rule id. It returns ErrClassificationRuleNotFound
// if the rule does not exist for that user.
func (r *ClassificationRepository) Delete(userID uint, id uint) error {
	// Load the rule first, scoped to the user, so that deleting another
	// user's rule reports not-found instead of succeeding.
	var rule models.ClassificationRule
	if err := r.db.Where("user_id = ?", userID).First(&rule, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrClassificationRuleNotFound
		}
		return fmt.Errorf("failed to check classification rule existence: %w", err)
	}

	if err := r.db.Delete(&rule).Error; err != nil {
		return fmt.Errorf("failed to delete classification rule: %w", err)
	}
	return nil
}

// ExistsByID reports whether userID owns a rule with the given id.
func (r *ClassificationRepository) ExistsByID(userID uint, id uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.ClassificationRule{}).
		Where("user_id = ? AND id = ?", userID, id).
		Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check classification rule existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByKeywordAndCategory reports whether userID already has a rule with
// the same keyword (case-insensitive equality) targeting categoryID.
func (r *ClassificationRepository) ExistsByKeywordAndCategory(userID uint, keyword string, categoryID uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.ClassificationRule{}).
		Where("user_id = ? AND LOWER(keyword) = LOWER(?) AND category_id = ?", userID, keyword, categoryID).
		Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check classification rule existence: %w", err)
	}
	return count > 0, nil
}

// DeleteByCategoryID deletes all of userID's rules that target categoryID.
// Deleting zero rows is not an error.
func (r *ClassificationRepository) DeleteByCategoryID(userID uint, categoryID uint) error {
	if err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID).
		Delete(&models.ClassificationRule{}).Error; err != nil {
		return fmt.Errorf("failed to delete classification rules by category: %w", err)
	}
	return nil
}