Files
Novault-backend/internal/repository/billing_repository.go
2026-01-25 21:59:00 +08:00

207 lines
7.0 KiB
Go

package repository
import (
"errors"
"fmt"
"time"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// ErrBillNotFound is the sentinel returned by BillingRepository methods when
// no credit card bill matches the requested user and ID. Callers should test
// for it with errors.Is.
var ErrBillNotFound = errors.New("bill not found")
// BillingRepository is the persistence layer for models.CreditCardBill
// records. All of its queries are issued through the wrapped GORM handle.
type BillingRepository struct {
	// db is the GORM handle used for every query in this repository.
	db *gorm.DB
}
// NewBillingRepository returns a BillingRepository that issues its queries
// through db.
func NewBillingRepository(db *gorm.DB) *BillingRepository {
	repo := new(BillingRepository)
	repo.db = db
	return repo
}
// Create inserts bill as a new row. Any database error is wrapped and
// returned; on success GORM populates generated fields on bill.
func (r *BillingRepository) Create(bill *models.CreditCardBill) error {
	err := r.db.Create(bill).Error
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create bill: %w", err)
}
// GetByID loads the bill with primary key id that belongs to userID, with its
// Account association preloaded.
//
// It returns ErrBillNotFound when no such row exists for that user; any other
// database error is wrapped.
func (r *BillingRepository) GetByID(userID uint, id uint) (*models.CreditCardBill, error) {
	bill := new(models.CreditCardBill)
	err := r.db.
		Preload("Account").
		Where("user_id = ?", userID).
		First(bill, id).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrBillNotFound
	case err != nil:
		return nil, fmt.Errorf("failed to get bill: %w", err)
	}
	return bill, nil
}
// GetByAccountID lists every bill of accountID owned by userID, newest
// billing date first, with the Account association preloaded. An account
// with no bills yields an empty result and a nil error.
func (r *BillingRepository) GetByAccountID(userID uint, accountID uint) ([]models.CreditCardBill, error) {
	query := r.db.
		Where("user_id = ? AND account_id = ?", userID, accountID).
		Order("billing_date DESC").
		Preload("Account")

	var result []models.CreditCardBill
	if err := query.Find(&result).Error; err != nil {
		return nil, fmt.Errorf("failed to get bills by account: %w", err)
	}
	return result, nil
}
// GetByAccountIDAndDateRange lists the bills of accountID owned by userID
// whose billing_date lies in [startDate, endDate] (both bounds inclusive).
// Results are ordered newest billing date first, with Account preloaded.
func (r *BillingRepository) GetByAccountIDAndDateRange(userID uint, accountID uint, startDate, endDate time.Time) ([]models.CreditCardBill, error) {
	const cond = "user_id = ? AND account_id = ? AND billing_date >= ? AND billing_date <= ?"

	var found []models.CreditCardBill
	err := r.db.
		Where(cond, userID, accountID, startDate, endDate).
		Order("billing_date DESC").
		Preload("Account").
		Find(&found).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get bills by date range: %w", err)
	}
	return found, nil
}
// GetLatestByAccountID returns the bill of accountID (owned by userID) with
// the greatest billing_date, with Account preloaded.
//
// It returns ErrBillNotFound when the account has no bills for that user;
// any other database error is wrapped.
func (r *BillingRepository) GetLatestByAccountID(userID uint, accountID uint) (*models.CreditCardBill, error) {
	latest := &models.CreditCardBill{}
	err := r.db.
		Where("user_id = ? AND account_id = ?", userID, accountID).
		Order("billing_date DESC").
		Preload("Account").
		First(latest).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrBillNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get latest bill: %w", err)
	}
	return latest, nil
}
// GetPendingBills lists userID's bills whose status is exactly
// models.BillStatusPending, earliest payment due date first, with Account
// preloaded. Bills in other statuses (e.g. overdue) are not included.
func (r *BillingRepository) GetPendingBills(userID uint) ([]models.CreditCardBill, error) {
	var pending []models.CreditCardBill
	tx := r.db.
		Where("user_id = ? AND status = ?", userID, models.BillStatusPending).
		Order("payment_due_date ASC").
		Preload("Account").
		Find(&pending)
	if tx.Error != nil {
		return nil, fmt.Errorf("failed to get pending bills: %w", tx.Error)
	}
	return pending, nil
}
// GetOverdueBills lists userID's bills whose status is
// models.BillStatusOverdue, earliest payment due date first, with Account
// preloaded.
func (r *BillingRepository) GetOverdueBills(userID uint) ([]models.CreditCardBill, error) {
	var overdue []models.CreditCardBill
	tx := r.db.
		Where("user_id = ? AND status = ?", userID, models.BillStatusOverdue).
		Order("payment_due_date ASC").
		Preload("Account").
		Find(&overdue)
	if tx.Error != nil {
		return nil, fmt.Errorf("failed to get overdue bills: %w", tx.Error)
	}
	return overdue, nil
}
// GetBillsDueInRange lists userID's bills that are not marked
// models.BillStatusPaid and whose payment_due_date lies in
// [startDate, endDate] (both bounds inclusive). Results are ordered earliest
// due date first, with Account preloaded.
func (r *BillingRepository) GetBillsDueInRange(userID uint, startDate, endDate time.Time) ([]models.CreditCardBill, error) {
	const cond = "user_id = ? AND payment_due_date >= ? AND payment_due_date <= ? AND status != ?"

	var due []models.CreditCardBill
	err := r.db.
		Where(cond, userID, startDate, endDate, models.BillStatusPaid).
		Order("payment_due_date ASC").
		Preload("Account").
		Find(&due).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get bills due in range: %w", err)
	}
	return due, nil
}
// Update persists every column of bill (via GORM Save) after confirming that
// a row with bill.ID exists and belongs to bill.UserID.
//
// The ownership check mirrors the user_id scoping applied by every other
// method of this repository. Without it, a caller could overwrite, and
// re-assign to itself, a bill owned by another user simply by supplying that
// bill's ID.
//
// It returns ErrBillNotFound when no row with bill.ID exists for bill.UserID.
// Any other database error is wrapped.
func (r *BillingRepository) Update(bill *models.CreditCardBill) error {
	// Existence + ownership check, scoped the same way as GetByID/Delete.
	var existing models.CreditCardBill
	if err := r.db.Where("user_id = ?", bill.UserID).First(&existing, bill.ID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrBillNotFound
		}
		return fmt.Errorf("failed to check bill existence: %w", err)
	}

	// Save writes all fields of bill, including zero values.
	if err := r.db.Save(bill).Error; err != nil {
		return fmt.Errorf("failed to update bill: %w", err)
	}
	return nil
}
// UpdateStatus sets the status column of the bill with primary key id owned
// by userID.
//
// It returns ErrBillNotFound when the UPDATE reports zero affected rows; any
// other database error is wrapped.
func (r *BillingRepository) UpdateStatus(userID uint, id uint, status models.BillStatus) error {
	res := r.db.
		Model(&models.CreditCardBill{}).
		Where("user_id = ? AND id = ?", userID, id).
		Update("status", status)
	switch {
	case res.Error != nil:
		return fmt.Errorf("failed to update bill status: %w", res.Error)
	case res.RowsAffected == 0:
		return ErrBillNotFound
	default:
		return nil
	}
}
// MarkAsPaid sets status to models.BillStatusPaid and records paidAmount and
// paidAt on the bill with primary key id owned by userID.
//
// A map is used, rather than a struct, so that zero values such as a
// paidAmount of 0 are still written. It returns ErrBillNotFound when the
// UPDATE reports zero affected rows; any other database error is wrapped.
func (r *BillingRepository) MarkAsPaid(userID uint, id uint, paidAmount float64, paidAt time.Time) error {
	changes := map[string]interface{}{
		"status":      models.BillStatusPaid,
		"paid_amount": paidAmount,
		"paid_at":     paidAt,
	}
	res := r.db.
		Model(&models.CreditCardBill{}).
		Where("user_id = ? AND id = ?", userID, id).
		Updates(changes)
	switch {
	case res.Error != nil:
		return fmt.Errorf("failed to mark bill as paid: %w", res.Error)
	case res.RowsAffected == 0:
		return ErrBillNotFound
	default:
		return nil
	}
}
// Delete removes the bill with primary key id owned by userID. The original
// code's comment describes this as a soft delete; whether GORM soft-deletes
// depends on the model definition (not visible here).
//
// The row is looked up first so that a missing bill yields ErrBillNotFound.
// Other errors from the lookup or the delete are wrapped.
func (r *BillingRepository) Delete(userID uint, id uint) error {
	target := new(models.CreditCardBill)
	lookupErr := r.db.Where("user_id = ?", userID).First(target, id).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return ErrBillNotFound
	case lookupErr != nil:
		return fmt.Errorf("failed to check bill existence: %w", lookupErr)
	}

	if err := r.db.Delete(target).Error; err != nil {
		return fmt.Errorf("failed to delete bill: %w", err)
	}
	return nil
}
// ExistsByAccountAndBillingDate reports whether userID has at least one bill
// for accountID whose billing_date equals billingDate. The comparison is an
// exact equality on the stored value.
func (r *BillingRepository) ExistsByAccountAndBillingDate(userID uint, accountID uint, billingDate time.Time) (bool, error) {
	q := r.db.
		Model(&models.CreditCardBill{}).
		Where("user_id = ? AND account_id = ? AND billing_date = ?", userID, accountID, billingDate)

	var n int64
	if err := q.Count(&n).Error; err != nil {
		return false, fmt.Errorf("failed to check bill existence: %w", err)
	}
	exists := n > 0
	return exists, nil
}
// CountByAccountID returns how many bills userID has for accountID. On error
// the count is 0 and the database error is wrapped.
func (r *BillingRepository) CountByAccountID(userID uint, accountID uint) (int64, error) {
	var total int64
	err := r.db.
		Model(&models.CreditCardBill{}).
		Where("user_id = ? AND account_id = ?", userID, accountID).
		Count(&total).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count bills by account: %w", err)
	}
	return total, nil
}