package repository

import (
	"errors"
	"fmt"

	"accounting-app/internal/models"
	"gorm.io/gorm"
)

// Common repository errors.
var (
	// ErrAccountNotFound is returned when no account with the requested ID or
	// name exists for the given user.
	ErrAccountNotFound = errors.New("account not found")
	// ErrAccountInUse is returned by Delete when transactions or recurring
	// transactions still reference the account.
	ErrAccountInUse = errors.New("account is in use and cannot be deleted")
)

// AccountRepository handles database operations for accounts. Most methods
// take the owning user's ID and restrict their queries to that user.
type AccountRepository struct {
	db *gorm.DB
}

// NewAccountRepository creates a new AccountRepository backed by db.
func NewAccountRepository(db *gorm.DB) *AccountRepository {
	return &AccountRepository{db: db}
}

// Create inserts account into the database.
func (r *AccountRepository) Create(account *models.Account) error {
	if err := r.db.Create(account).Error; err != nil {
		return fmt.Errorf("failed to create account: %w", err)
	}
	return nil
}

// GetByID returns the account with the given ID owned by userID.
// It returns ErrAccountNotFound when no such account exists.
func (r *AccountRepository) GetByID(userID uint, id uint) (*models.Account, error) {
	var account models.Account
	if err := r.db.Where("user_id = ?", userID).First(&account, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrAccountNotFound
		}
		return nil, fmt.Errorf("failed to get account: %w", err)
	}
	return &account, nil
}

// GetAll returns all accounts owned by userID, ordered by sort_order
// ascending and then by creation time, newest first.
// Feature: accounting-feature-upgrade - Orders by sort_order for consistent display
// Validates: Requirements 1.3, 1.4
func (r *AccountRepository) GetAll(userID uint) ([]models.Account, error) {
	var accounts []models.Account
	if err := r.db.Where("user_id = ?", userID).Order("sort_order ASC, created_at DESC").Find(&accounts).Error; err != nil {
		return nil, fmt.Errorf("failed to get accounts: %w", err)
	}
	return accounts, nil
}

// Update persists account, provided an account with the same ID already
// exists for account.UserID; otherwise it returns ErrAccountNotFound.
//
// gorm's Save writes every column, including zero values, so callers should
// pass a fully populated account (e.g. one previously loaded via GetByID).
func (r *AccountRepository) Update(account *models.Account) error {
	// Verify ownership/existence first so Save never inserts a new row.
	var existing models.Account
	if err := r.db.Where("user_id = ?", account.UserID).First(&existing, account.ID).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrAccountNotFound
		}
		return fmt.Errorf("failed to check account existence: %w", err)
	}

	if err := r.db.Save(account).Error; err != nil {
		return fmt.Errorf("failed to update account: %w", err)
	}
	return nil
}

// Delete soft-deletes the account with the given ID owned by userID.
// It returns ErrAccountNotFound if the account does not exist and
// ErrAccountInUse if any transaction or recurring transaction references it.
func (r *AccountRepository) Delete(userID uint, id uint) error {
	var account models.Account
	if err := r.db.Where("user_id = ?", userID).First(&account, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return ErrAccountNotFound
		}
		return fmt.Errorf("failed to check account existence: %w", err)
	}

	if err := r.ensureNotInUse(id); err != nil {
		return err
	}

	// Soft delete due to the model's gorm.DeletedAt field.
	if err := r.db.Delete(&account).Error; err != nil {
		return fmt.Errorf("failed to delete account: %w", err)
	}
	return nil
}

// ensureNotInUse returns ErrAccountInUse when a transaction (as either the
// source or the destination account) or a recurring transaction references
// the account with the given id.
//
// NOTE(review): these checks and the subsequent delete are not atomic; a
// concurrently inserted transaction could still slip in between them.
func (r *AccountRepository) ensureNotInUse(id uint) error {
	var transactionCount int64
	// Parenthesised so the OR cannot bind with conditions gorm may append
	// (such as a soft-delete filter).
	if err := r.db.Model(&models.Transaction{}).
		Where("(account_id = ? OR to_account_id = ?)", id, id).
		Count(&transactionCount).Error; err != nil {
		return fmt.Errorf("failed to check account transactions: %w", err)
	}
	if transactionCount > 0 {
		return ErrAccountInUse
	}

	var recurringCount int64
	if err := r.db.Model(&models.RecurringTransaction{}).
		Where("account_id = ?", id).
		Count(&recurringCount).Error; err != nil {
		return fmt.Errorf("failed to check account recurring transactions: %w", err)
	}
	if recurringCount > 0 {
		return ErrAccountInUse
	}
	return nil
}

// GetByType returns userID's accounts of the given type, newest first.
func (r *AccountRepository) GetByType(userID uint, accountType models.AccountType) ([]models.Account, error) {
	var accounts []models.Account
	if err := r.db.Where("user_id = ? AND type = ?", userID, accountType).Order("created_at DESC").Find(&accounts).Error; err != nil {
		return nil, fmt.Errorf("failed to get accounts by type: %w", err)
	}
	return accounts, nil
}

// GetByCurrency returns userID's accounts in the given currency, newest first.
func (r *AccountRepository) GetByCurrency(userID uint, currency models.Currency) ([]models.Account, error) {
	var accounts []models.Account
	if err := r.db.Where("user_id = ? AND currency = ?", userID, currency).Order("created_at DESC").Find(&accounts).Error; err != nil {
		return nil, fmt.Errorf("failed to get accounts by currency: %w", err)
	}
	return accounts, nil
}

// GetCreditAccounts returns userID's accounts flagged is_credit (credit cards
// and credit lines), newest first.
func (r *AccountRepository) GetCreditAccounts(userID uint) ([]models.Account, error) {
	var accounts []models.Account
	if err := r.db.Where("user_id = ? AND is_credit = ?", userID, true).Order("created_at DESC").Find(&accounts).Error; err != nil {
		return nil, fmt.Errorf("failed to get credit accounts: %w", err)
	}
	return accounts, nil
}

// GetTotalBalance sums the balances of all accounts owned by userID.
// Non-negative balances are added to assets; negative balances are added to
// liabilities as positive amounts.
func (r *AccountRepository) GetTotalBalance(userID uint) (assets float64, liabilities float64, err error) {
	// Only the balance column is needed, so avoid loading full account rows.
	var balances []float64
	err = r.db.Model(&models.Account{}).Where("user_id = ?", userID).Pluck("balance", &balances).Error
	if err != nil {
		return 0, 0, fmt.Errorf("failed to get accounts for balance calculation: %w", err)
	}

	for _, balance := range balances {
		if balance >= 0 {
			assets += balance
		} else {
			liabilities -= balance // report liabilities as a positive amount
		}
	}
	return assets, liabilities, nil
}

// UpdateBalance sets only the balance column of the account with the given
// ID owned by userID. It returns ErrAccountNotFound if no such account exists.
func (r *AccountRepository) UpdateBalance(userID uint, id uint, newBalance float64) error {
	return r.updateColumn(userID, id, "balance", newBalance, "balance")
}

// ExistsByID reports whether an account with the given ID exists for userID.
func (r *AccountRepository) ExistsByID(userID uint, id uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Account{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check account existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByName reports whether userID already has an account with the given name.
func (r *AccountRepository) ExistsByName(userID uint, name string) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Account{}).Where("user_id = ? AND name = ?", userID, name).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check account name existence: %w", err)
	}
	return count > 0, nil
}

// ExistsByNameExcludingID reports whether userID has an account with the
// given name other than the one identified by excludeID. It is intended for
// duplicate-name checks during updates.
func (r *AccountRepository) ExistsByNameExcludingID(userID uint, name string, excludeID uint) (bool, error) {
	var count int64
	if err := r.db.Model(&models.Account{}).Where("user_id = ? AND name = ? AND id != ?", userID, name, excludeID).Count(&count).Error; err != nil {
		return false, fmt.Errorf("failed to check account name existence: %w", err)
	}
	return count > 0, nil
}

// GetByName returns userID's account with the given name.
// It returns ErrAccountNotFound when no such account exists.
func (r *AccountRepository) GetByName(userID uint, name string) (*models.Account, error) {
	var account models.Account
	if err := r.db.Where("user_id = ? AND name = ?", userID, name).First(&account).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrAccountNotFound
		}
		return nil, fmt.Errorf("failed to get account by name: %w", err)
	}
	return &account, nil
}

// UpdateSortOrder sets only the sort_order column of the account with the
// given ID owned by userID. It returns ErrAccountNotFound if no such account
// exists.
// Feature: accounting-feature-upgrade
// Validates: Requirements 1.3, 1.4
func (r *AccountRepository) UpdateSortOrder(userID uint, id uint, sortOrder int) error {
	return r.updateColumn(userID, id, "sort_order", sortOrder, "sort order")
}

// updateColumn sets a single column on the account (userID, id) and maps a
// missing account to ErrAccountNotFound. label names the updated field in the
// wrapped error message ("failed to update account <label>: ...").
func (r *AccountRepository) updateColumn(userID uint, id uint, column string, value interface{}, label string) error {
	result := r.db.Model(&models.Account{}).Where("user_id = ? AND id = ?", userID, id).Update(column, value)
	if result.Error != nil {
		return fmt.Errorf("failed to update account %s: %w", label, result.Error)
	}
	if result.RowsAffected > 0 {
		return nil
	}

	// Whether an update that leaves the row unchanged counts as "affected" is
	// driver-dependent (MySQL, by default, reports changed rather than matched
	// rows), so zero affected rows does not by itself prove the account is
	// missing. Confirm before reporting not-found.
	exists, err := r.ExistsByID(userID, id)
	if err != nil {
		return err
	}
	if !exists {
		return ErrAccountNotFound
	}
	return nil
}