package repository import ( "errors" "fmt" "time" "accounting-app/internal/models" "gorm.io/gorm" ) // Billing repository errors var ( ErrBillNotFound = errors.New("bill not found") ) // BillingRepository handles database operations for credit card bills type BillingRepository struct { db *gorm.DB } // NewBillingRepository creates a new BillingRepository instance func NewBillingRepository(db *gorm.DB) *BillingRepository { return &BillingRepository{db: db} } // Create creates a new bill in the database func (r *BillingRepository) Create(bill *models.CreditCardBill) error { if err := r.db.Create(bill).Error; err != nil { return fmt.Errorf("failed to create bill: %w", err) } return nil } // GetByID retrieves a bill by its ID func (r *BillingRepository) GetByID(userID uint, id uint) (*models.CreditCardBill, error) { var bill models.CreditCardBill if err := r.db.Preload("Account").Where("user_id = ?", userID).First(&bill, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrBillNotFound } return nil, fmt.Errorf("failed to get bill: %w", err) } return &bill, nil } // GetByAccountID retrieves all bills for a specific account func (r *BillingRepository) GetByAccountID(userID uint, accountID uint) ([]models.CreditCardBill, error) { var bills []models.CreditCardBill if err := r.db.Where("user_id = ? AND account_id = ?", userID, accountID). Order("billing_date DESC"). Preload("Account"). Find(&bills).Error; err != nil { return nil, fmt.Errorf("failed to get bills by account: %w", err) } return bills, nil } // GetByAccountIDAndDateRange retrieves bills for an account within a date range func (r *BillingRepository) GetByAccountIDAndDateRange(userID uint, accountID uint, startDate, endDate time.Time) ([]models.CreditCardBill, error) { var bills []models.CreditCardBill if err := r.db.Where("user_id = ? AND account_id = ? AND billing_date >= ? AND billing_date <= ?", userID, accountID, startDate, endDate). Order("billing_date DESC"). Preload("Account"). Find(&bills).Error; err != nil { return nil, fmt.Errorf("failed to get bills by date range: %w", err) } return bills, nil } // GetLatestByAccountID retrieves the most recent bill for an account func (r *BillingRepository) GetLatestByAccountID(userID uint, accountID uint) (*models.CreditCardBill, error) { var bill models.CreditCardBill if err := r.db.Where("user_id = ? AND account_id = ?", userID, accountID). Order("billing_date DESC"). Preload("Account"). First(&bill).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrBillNotFound } return nil, fmt.Errorf("failed to get latest bill: %w", err) } return &bill, nil } // GetPendingBills retrieves all pending bills (not yet paid) func (r *BillingRepository) GetPendingBills(userID uint) ([]models.CreditCardBill, error) { var bills []models.CreditCardBill if err := r.db.Where("user_id = ? AND status = ?", userID, models.BillStatusPending). Order("payment_due_date ASC"). Preload("Account"). Find(&bills).Error; err != nil { return nil, fmt.Errorf("failed to get pending bills: %w", err) } return bills, nil } // GetOverdueBills retrieves all overdue bills func (r *BillingRepository) GetOverdueBills(userID uint) ([]models.CreditCardBill, error) { var bills []models.CreditCardBill if err := r.db.Where("user_id = ? AND status = ?", userID, models.BillStatusOverdue). Order("payment_due_date ASC"). Preload("Account"). Find(&bills).Error; err != nil { return nil, fmt.Errorf("failed to get overdue bills: %w", err) } return bills, nil } // GetBillsDueInRange retrieves bills with payment due dates in a specific range func (r *BillingRepository) GetBillsDueInRange(userID uint, startDate, endDate time.Time) ([]models.CreditCardBill, error) { var bills []models.CreditCardBill if err := r.db.Where("user_id = ? AND payment_due_date >= ? AND payment_due_date <= ? AND status != ?", userID, startDate, endDate, models.BillStatusPaid). Order("payment_due_date ASC"). Preload("Account"). Find(&bills).Error; err != nil { return nil, fmt.Errorf("failed to get bills due in range: %w", err) } return bills, nil } // Update updates an existing bill func (r *BillingRepository) Update(bill *models.CreditCardBill) error { // First check if the bill exists var existing models.CreditCardBill if err := r.db.First(&existing, bill.ID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrBillNotFound } return fmt.Errorf("failed to check bill existence: %w", err) } // Update the bill if err := r.db.Save(bill).Error; err != nil { return fmt.Errorf("failed to update bill: %w", err) } return nil } // UpdateStatus updates the status of a bill func (r *BillingRepository) UpdateStatus(userID uint, id uint, status models.BillStatus) error { result := r.db.Model(&models.CreditCardBill{}).Where("user_id = ? AND id = ?", userID, id).Update("status", status) if result.Error != nil { return fmt.Errorf("failed to update bill status: %w", result.Error) } if result.RowsAffected == 0 { return ErrBillNotFound } return nil } // MarkAsPaid marks a bill as paid func (r *BillingRepository) MarkAsPaid(userID uint, id uint, paidAmount float64, paidAt time.Time) error { result := r.db.Model(&models.CreditCardBill{}).Where("user_id = ? AND id = ?", userID, id).Updates(map[string]interface{}{ "status": models.BillStatusPaid, "paid_amount": paidAmount, "paid_at": paidAt, }) if result.Error != nil { return fmt.Errorf("failed to mark bill as paid: %w", result.Error) } if result.RowsAffected == 0 { return ErrBillNotFound } return nil } // Delete deletes a bill by its ID func (r *BillingRepository) Delete(userID uint, id uint) error { // First check if the bill exists var bill models.CreditCardBill if err := r.db.Where("user_id = ?", userID).First(&bill, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrBillNotFound } return fmt.Errorf("failed to check bill existence: %w", err) } // Delete the bill (soft delete) if err := r.db.Delete(&bill).Error; err != nil { return fmt.Errorf("failed to delete bill: %w", err) } return nil } // ExistsByAccountAndBillingDate checks if a bill exists for an account on a specific billing date func (r *BillingRepository) ExistsByAccountAndBillingDate(userID uint, accountID uint, billingDate time.Time) (bool, error) { var count int64 if err := r.db.Model(&models.CreditCardBill{}). Where("user_id = ? AND account_id = ? AND billing_date = ?", userID, accountID, billingDate). Count(&count).Error; err != nil { return false, fmt.Errorf("failed to check bill existence: %w", err) } return count > 0, nil } // CountByAccountID returns the count of bills for an account func (r *BillingRepository) CountByAccountID(userID uint, accountID uint) (int64, error) { var count int64 if err := r.db.Model(&models.CreditCardBill{}).Where("user_id = ? AND account_id = ?", userID, accountID).Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to count bills by account: %w", err) } return count, nil }