package repository

import (
	"errors"
	"fmt"
	"time"

	"accounting-app/internal/models"

	"gorm.io/gorm"
)

// Recurring transaction repository errors
var (
	// ErrRecurringTransactionNotFound is returned when no recurring transaction
	// with the requested ID exists for the requesting user.
	ErrRecurringTransactionNotFound = errors.New("recurring transaction not found")

	// ErrNilRecurringTransaction is returned by Create and Update when they are
	// given a nil *models.RecurringTransaction.
	ErrNilRecurringTransaction = errors.New("recurring transaction is nil")
)

// recurringTransactionOrder sorts by next occurrence and breaks ties by ID so
// that list results come back in a deterministic order.
const recurringTransactionOrder = "next_occurrence ASC, id ASC"

// wrapRecurringTransactionLookupErr maps gorm.ErrRecordNotFound to
// ErrRecurringTransactionNotFound and wraps any other error as "msg: err".
func wrapRecurringTransactionLookupErr(err error, msg string) error {
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return ErrRecurringTransactionNotFound
	}
	return fmt.Errorf("%s: %w", msg, err)
}

// RecurringTransactionRepository handles database operations for recurring transactions.
// Every read, update and delete is scoped to a single user via the user_id column.
type RecurringTransactionRepository struct {
	db *gorm.DB
}

// NewRecurringTransactionRepository creates a new RecurringTransactionRepository instance
// backed by db.
func NewRecurringTransactionRepository(db *gorm.DB) *RecurringTransactionRepository {
	return &RecurringTransactionRepository{db: db}
}

// ownedBy returns a query restricted to rows whose user_id equals userID.
func (r *RecurringTransactionRepository) ownedBy(userID uint) *gorm.DB {
	return r.db.Where("user_id = ?", userID)
}

// Create inserts recurringTransaction into the database.
// It returns ErrNilRecurringTransaction if recurringTransaction is nil.
func (r *RecurringTransactionRepository) Create(recurringTransaction *models.RecurringTransaction) error {
	if recurringTransaction == nil {
		return ErrNilRecurringTransaction
	}
	if err := r.db.Create(recurringTransaction).Error; err != nil {
		return fmt.Errorf("failed to create recurring transaction: %w", err)
	}
	return nil
}

// GetByID retrieves the recurring transaction with the given ID owned by userID.
// It returns ErrRecurringTransactionNotFound if no such row exists.
func (r *RecurringTransactionRepository) GetByID(userID uint, id uint) (*models.RecurringTransaction, error) {
	var recurringTransaction models.RecurringTransaction
	if err := r.ownedBy(userID).First(&recurringTransaction, id).Error; err != nil {
		return nil, wrapRecurringTransactionLookupErr(err, "failed to get recurring transaction")
	}
	return &recurringTransaction, nil
}

// GetByIDWithRelations is like GetByID but also preloads the Category and
// Account associations.
func (r *RecurringTransactionRepository) GetByIDWithRelations(userID uint, id uint) (*models.RecurringTransaction, error) {
	var recurringTransaction models.RecurringTransaction
	err := r.ownedBy(userID).
		Preload("Category").
		Preload("Account").
		First(&recurringTransaction, id).Error
	if err != nil {
		return nil, wrapRecurringTransactionLookupErr(err, "failed to get recurring transaction with relations")
	}
	return &recurringTransaction, nil
}

// Update saves all fields of recurringTransaction. It first confirms that a row
// with recurringTransaction.ID exists for recurringTransaction.UserID and returns
// ErrRecurringTransactionNotFound otherwise. It returns ErrNilRecurringTransaction
// if recurringTransaction is nil.
func (r *RecurringTransactionRepository) Update(recurringTransaction *models.RecurringTransaction) error {
	if recurringTransaction == nil {
		return ErrNilRecurringTransaction
	}

	// The existence check is required, not just defensive: gorm's Save can fall
	// back to an insert/upsert when its UPDATE matches no rows, so an unknown or
	// foreign ID must be rejected before Save is reached.
	var existing models.RecurringTransaction
	if err := r.ownedBy(recurringTransaction.UserID).First(&existing, recurringTransaction.ID).Error; err != nil {
		return wrapRecurringTransactionLookupErr(err, "failed to check recurring transaction existence")
	}

	if err := r.db.Save(recurringTransaction).Error; err != nil {
		return fmt.Errorf("failed to update recurring transaction: %w", err)
	}
	return nil
}

// Delete deletes the recurring transaction with the given ID owned by userID.
// It is a soft delete because the model carries a gorm.DeletedAt field.
// It returns ErrRecurringTransactionNotFound if no such row exists.
func (r *RecurringTransactionRepository) Delete(userID uint, id uint) error {
	// Load the full record first so that the delete operates on the populated model.
	var recurringTransaction models.RecurringTransaction
	if err := r.ownedBy(userID).First(&recurringTransaction, id).Error; err != nil {
		return wrapRecurringTransactionLookupErr(err, "failed to check recurring transaction existence")
	}

	// Scope the delete by owner as well as by primary key, and treat zero
	// affected rows as not-found: the row may have been deleted by someone else
	// between the lookup above and this statement.
	result := r.ownedBy(userID).Delete(&recurringTransaction)
	if result.Error != nil {
		return fmt.Errorf("failed to delete recurring transaction: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return ErrRecurringTransactionNotFound
	}
	return nil
}

// List retrieves all recurring transactions for a user with Category and Account
// preloaded, ordered by next occurrence (ties broken by ID).
func (r *RecurringTransactionRepository) List(userID uint) ([]models.RecurringTransaction, error) {
	var recurringTransactions []models.RecurringTransaction
	err := r.ownedBy(userID).
		Preload("Category").
		Preload("Account").
		Order(recurringTransactionOrder).
		Find(&recurringTransactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to list recurring transactions: %w", err)
	}
	return recurringTransactions, nil
}

// GetActive retrieves all recurring transactions with is_active = true for a user,
// with Category and Account preloaded, ordered by next occurrence (ties broken by ID).
func (r *RecurringTransactionRepository) GetActive(userID uint) ([]models.RecurringTransaction, error) {
	var recurringTransactions []models.RecurringTransaction
	err := r.ownedBy(userID).
		Where("is_active = ?", true).
		Preload("Category").
		Preload("Account").
		Order(recurringTransactionOrder).
		Find(&recurringTransactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get active recurring transactions: %w", err)
	}
	return recurringTransactions, nil
}

// GetDueTransactions retrieves all active recurring transactions for a user that
// are due, i.e. whose next_occurrence is at or before now. Category and Account
// are preloaded; results are ordered by next occurrence (ties broken by ID).
func (r *RecurringTransactionRepository) GetDueTransactions(userID uint, now time.Time) ([]models.RecurringTransaction, error) {
	var recurringTransactions []models.RecurringTransaction
	err := r.ownedBy(userID).
		Where("is_active = ? AND next_occurrence <= ?", true, now).
		Preload("Category").
		Preload("Account").
		Order(recurringTransactionOrder).
		Find(&recurringTransactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get due recurring transactions: %w", err)
	}
	return recurringTransactions, nil
}

// GetByAccountID retrieves all recurring transactions for a specific account and
// user. Only Category is preloaded; results are ordered by next occurrence (ties
// broken by ID).
func (r *RecurringTransactionRepository) GetByAccountID(userID, accountID uint) ([]models.RecurringTransaction, error) {
	var recurringTransactions []models.RecurringTransaction
	err := r.ownedBy(userID).
		Where("account_id = ?", accountID).
		Preload("Category").
		Order(recurringTransactionOrder).
		Find(&recurringTransactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get recurring transactions by account: %w", err)
	}
	return recurringTransactions, nil
}

// GetByCategoryID retrieves all recurring transactions for a specific category and
// user. Only Account is preloaded; results are ordered by next occurrence (ties
// broken by ID).
func (r *RecurringTransactionRepository) GetByCategoryID(userID, categoryID uint) ([]models.RecurringTransaction, error) {
	var recurringTransactions []models.RecurringTransaction
	err := r.ownedBy(userID).
		Where("category_id = ?", categoryID).
		Preload("Account").
		Order(recurringTransactionOrder).
		Find(&recurringTransactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get recurring transactions by category: %w", err)
	}
	return recurringTransactions, nil
}

// ExistsByID reports whether a recurring transaction with the given ID exists for a user.
func (r *RecurringTransactionRepository) ExistsByID(userID, id uint) (bool, error) {
	var count int64
	err := r.ownedBy(userID).
		Model(&models.RecurringTransaction{}).
		Where("id = ?", id).
		Count(&count).Error
	if err != nil {
		return false, fmt.Errorf("failed to check recurring transaction existence: %w", err)
	}
	return count > 0, nil
}

// CountByAccountID returns the number of recurring transactions for an account and user.
func (r *RecurringTransactionRepository) CountByAccountID(userID, accountID uint) (int64, error) {
	var count int64
	err := r.ownedBy(userID).
		Model(&models.RecurringTransaction{}).
		Where("account_id = ?", accountID).
		Count(&count).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count recurring transactions by account: %w", err)
	}
	return count, nil
}

// CountByCategoryID returns the number of recurring transactions for a category and user.
func (r *RecurringTransactionRepository) CountByCategoryID(userID, categoryID uint) (int64, error) {
	var count int64
	err := r.ownedBy(userID).
		Model(&models.RecurringTransaction{}).
		Where("category_id = ?", categoryID).
		Count(&count).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count recurring transactions by category: %w", err)
	}
	return count, nil
}