529 lines
17 KiB
Go
529 lines
17 KiB
Go
package repository
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"accounting-app/internal/models"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// ErrTransactionNotFound is the sentinel error the repository returns when a
// lookup scoped to a user finds no matching transaction row.
var ErrTransactionNotFound = errors.New("transaction not found")
|
|
|
|
// TransactionFilter contains filter options for listing transactions
|
|
type TransactionFilter struct {
|
|
// Date range filters
|
|
StartDate *time.Time
|
|
EndDate *time.Time
|
|
|
|
// Entity filters
|
|
CategoryID *uint
|
|
AccountID *uint
|
|
TagIDs []uint
|
|
Type *models.TransactionType
|
|
Currency *models.Currency
|
|
RecurringID *uint
|
|
UserID *uint
|
|
|
|
// Search
|
|
NoteSearch string
|
|
}
|
|
|
|
// TransactionSort defines sorting options for a transaction listing.
type TransactionSort struct {
	// Field names the column to order by. applySorting accepts
	// "transaction_date", "amount" and "created_at"; an empty or unrecognized
	// value selects its default ordering instead.
	Field string // "transaction_date", "amount", "created_at"
	// Ascending selects ASC when true and DESC when false.
	Ascending bool
}
|
|
|
|
// TransactionListOptions contains options for listing transactions
|
|
type TransactionListOptions struct {
|
|
Filter TransactionFilter
|
|
Sort TransactionSort
|
|
Offset int
|
|
Limit int
|
|
}
|
|
|
|
// TransactionListResult contains the result of a paginated transaction list query
|
|
type TransactionListResult struct {
|
|
Transactions []models.Transaction
|
|
Total int64
|
|
Offset int
|
|
Limit int
|
|
}
|
|
|
|
// TransactionRepository handles database operations for transactions.
// Every method that accepts a userID adds a user_id condition to its query.
type TransactionRepository struct {
	// db is the GORM handle that every method issues its queries through.
	db *gorm.DB
}
|
|
|
|
// NewTransactionRepository creates a new TransactionRepository instance
|
|
func NewTransactionRepository(db *gorm.DB) *TransactionRepository {
|
|
return &TransactionRepository{db: db}
|
|
}
|
|
|
|
// Create creates a new transaction in the database
|
|
func (r *TransactionRepository) Create(transaction *models.Transaction) error {
|
|
if err := r.db.Create(transaction).Error; err != nil {
|
|
return fmt.Errorf("failed to create transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateWithTags creates a new transaction with associated tags
|
|
func (r *TransactionRepository) CreateWithTags(transaction *models.Transaction, tagIDs []uint) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// Create the transaction
|
|
if err := tx.Create(transaction).Error; err != nil {
|
|
return fmt.Errorf("failed to create transaction: %w", err)
|
|
}
|
|
|
|
// Add tags
|
|
for _, tagID := range tagIDs {
|
|
transactionTag := models.TransactionTag{
|
|
TransactionID: transaction.ID,
|
|
TagID: tagID,
|
|
}
|
|
if err := tx.Create(&transactionTag).Error; err != nil {
|
|
return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// GetByID retrieves a transaction by its ID
|
|
func (r *TransactionRepository) GetByID(userID uint, id uint) (*models.Transaction, error) {
|
|
var transaction models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrTransactionNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get transaction: %w", err)
|
|
}
|
|
return &transaction, nil
|
|
}
|
|
|
|
// GetByIDWithRelations retrieves a transaction by its ID with all relations preloaded
|
|
func (r *TransactionRepository) GetByIDWithRelations(userID uint, id uint) (*models.Transaction, error) {
|
|
var transaction models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).
|
|
Preload("Category").
|
|
Preload("Account").
|
|
Preload("ToAccount").
|
|
Preload("Tags").
|
|
Preload("Recurring").
|
|
First(&transaction, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrTransactionNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get transaction with relations: %w", err)
|
|
}
|
|
return &transaction, nil
|
|
}
|
|
|
|
// Update updates an existing transaction in the database
|
|
func (r *TransactionRepository) Update(transaction *models.Transaction) error {
|
|
// First check if the transaction exists
|
|
var existing models.Transaction
|
|
if err := r.db.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrTransactionNotFound
|
|
}
|
|
return fmt.Errorf("failed to check transaction existence: %w", err)
|
|
}
|
|
|
|
// Update the transaction
|
|
if err := r.db.Save(transaction).Error; err != nil {
|
|
return fmt.Errorf("failed to update transaction: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateWithTags updates a transaction and its associated tags
|
|
func (r *TransactionRepository) UpdateWithTags(transaction *models.Transaction, tagIDs []uint) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// First check if the transaction exists
|
|
var existing models.Transaction
|
|
if err := tx.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrTransactionNotFound
|
|
}
|
|
return fmt.Errorf("failed to check transaction existence: %w", err)
|
|
}
|
|
|
|
// Update the transaction
|
|
if err := tx.Save(transaction).Error; err != nil {
|
|
return fmt.Errorf("failed to update transaction: %w", err)
|
|
}
|
|
|
|
// Clear existing tags
|
|
if err := tx.Where("transaction_id = ?", transaction.ID).Delete(&models.TransactionTag{}).Error; err != nil {
|
|
return fmt.Errorf("failed to clear existing tags: %w", err)
|
|
}
|
|
|
|
// Add new tags
|
|
for _, tagID := range tagIDs {
|
|
transactionTag := models.TransactionTag{
|
|
TransactionID: transaction.ID,
|
|
TagID: tagID,
|
|
}
|
|
if err := tx.Create(&transactionTag).Error; err != nil {
|
|
return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// Delete deletes a transaction by its ID (soft delete)
|
|
func (r *TransactionRepository) Delete(userID uint, id uint) error {
|
|
// First check if the transaction exists
|
|
var transaction models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return ErrTransactionNotFound
|
|
}
|
|
return fmt.Errorf("failed to check transaction existence: %w", err)
|
|
}
|
|
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// Delete associated tags
|
|
if err := tx.Where("transaction_id = ?", id).Delete(&models.TransactionTag{}).Error; err != nil {
|
|
return fmt.Errorf("failed to delete transaction tags: %w", err)
|
|
}
|
|
|
|
// Delete the transaction (soft delete due to gorm.DeletedAt field)
|
|
if err := tx.Delete(&transaction).Error; err != nil {
|
|
return fmt.Errorf("failed to delete transaction: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// List retrieves transactions with pagination, filtering, and sorting
|
|
func (r *TransactionRepository) List(userID uint, options TransactionListOptions) (*TransactionListResult, error) {
|
|
query := r.db.Model(&models.Transaction{}).Where("user_id = ?", userID)
|
|
|
|
// Apply filters
|
|
query = r.applyFilters(query, options.Filter)
|
|
|
|
// Count total before pagination
|
|
var total int64
|
|
if err := query.Count(&total).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to count transactions: %w", err)
|
|
}
|
|
|
|
// Apply sorting (default: transaction_date DESC)
|
|
query = r.applySorting(query, options.Sort)
|
|
|
|
// Apply pagination
|
|
if options.Limit > 0 {
|
|
query = query.Limit(options.Limit)
|
|
}
|
|
if options.Offset > 0 {
|
|
query = query.Offset(options.Offset)
|
|
}
|
|
|
|
// Preload relations
|
|
query = query.Preload("Category").
|
|
Preload("Account").
|
|
Preload("ToAccount").
|
|
Preload("Tags")
|
|
|
|
// Execute query
|
|
var transactions []models.Transaction
|
|
if err := query.Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to list transactions: %w", err)
|
|
}
|
|
|
|
return &TransactionListResult{
|
|
Transactions: transactions,
|
|
Total: total,
|
|
Offset: options.Offset,
|
|
Limit: options.Limit,
|
|
}, nil
|
|
}
|
|
|
|
// applyFilters applies filter conditions to the query
|
|
func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionFilter) *gorm.DB {
|
|
// Date range filters
|
|
if filter.StartDate != nil {
|
|
query = query.Where("transaction_date >= ?", filter.StartDate)
|
|
}
|
|
if filter.EndDate != nil {
|
|
query = query.Where("transaction_date <= ?", filter.EndDate)
|
|
}
|
|
|
|
// Entity filters
|
|
if filter.CategoryID != nil {
|
|
query = query.Where("category_id = ?", *filter.CategoryID)
|
|
}
|
|
if filter.AccountID != nil {
|
|
query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID)
|
|
}
|
|
if filter.Type != nil {
|
|
query = query.Where("type = ?", *filter.Type)
|
|
}
|
|
if filter.Currency != nil {
|
|
query = query.Where("currency = ?", *filter.Currency)
|
|
}
|
|
if filter.RecurringID != nil {
|
|
query = query.Where("recurring_id = ?", *filter.RecurringID)
|
|
}
|
|
// UserID provided in argument takes precedence, but if filter has it, we can redundant check or ignore.
|
|
// The caller `List` already applied `Where("user_id = ?", userID)`.
|
|
|
|
// Tag filter - requires subquery
|
|
if len(filter.TagIDs) > 0 {
|
|
query = query.Where("id IN (?)",
|
|
r.db.Model(&models.TransactionTag{}).
|
|
Select("transaction_id").
|
|
Where("tag_id IN ?", filter.TagIDs))
|
|
}
|
|
|
|
// Note search
|
|
if filter.NoteSearch != "" {
|
|
query = query.Where("note LIKE ?", "%"+filter.NoteSearch+"%")
|
|
}
|
|
|
|
return query
|
|
}
|
|
|
|
// applySorting applies sorting to the query
|
|
func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB {
|
|
// Default sorting: transaction_date DESC (newest first)
|
|
if sort.Field == "" {
|
|
return query.Order("transaction_date DESC, created_at DESC")
|
|
}
|
|
|
|
// Validate sort field
|
|
validFields := map[string]bool{
|
|
"transaction_date": true,
|
|
"amount": true,
|
|
"created_at": true,
|
|
}
|
|
|
|
if !validFields[sort.Field] {
|
|
return query.Order("transaction_date DESC, created_at DESC")
|
|
}
|
|
|
|
direction := "DESC"
|
|
if sort.Ascending {
|
|
direction = "ASC"
|
|
}
|
|
|
|
return query.Order(fmt.Sprintf("%s %s", sort.Field, direction))
|
|
}
|
|
|
|
// GetByAccountID retrieves all transactions for a specific account
|
|
func (r *TransactionRepository) GetByAccountID(userID uint, accountID uint) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID).
|
|
Order("transaction_date DESC").
|
|
Preload("Category").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get transactions by account: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// GetByCategoryID retrieves all transactions for a specific category
|
|
func (r *TransactionRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID).
|
|
Order("transaction_date DESC").
|
|
Preload("Account").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get transactions by category: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// GetByDateRange retrieves all transactions within a date range
|
|
func (r *TransactionRepository) GetByDateRange(userID uint, startDate, endDate time.Time) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
|
|
Order("transaction_date DESC").
|
|
Preload("Category").
|
|
Preload("Account").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get transactions by date range: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// GetByTagID retrieves all transactions with a specific tag
|
|
func (r *TransactionRepository) GetByTagID(userID uint, tagID uint) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Joins("JOIN transaction_tags ON transaction_tags.transaction_id = transactions.id").
|
|
Where("transactions.user_id = ? AND transaction_tags.tag_id = ?", userID, tagID).
|
|
Order("transaction_date DESC").
|
|
Preload("Category").
|
|
Preload("Account").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get transactions by tag: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// GetByRecurringID retrieves all transactions generated from a recurring transaction
|
|
func (r *TransactionRepository) GetByRecurringID(userID uint, recurringID uint) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Where("user_id = ? AND recurring_id = ?", userID, recurringID).
|
|
Order("transaction_date DESC").
|
|
Preload("Category").
|
|
Preload("Account").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get transactions by recurring ID: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// ExistsByID checks if a transaction with the given ID exists
|
|
func (r *TransactionRepository) ExistsByID(userID uint, id uint) (bool, error) {
|
|
var count int64
|
|
if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil {
|
|
return false, fmt.Errorf("failed to check transaction existence: %w", err)
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
// CountByAccountID returns the count of transactions for an account
|
|
func (r *TransactionRepository) CountByAccountID(userID uint, accountID uint) (int64, error) {
|
|
var count int64
|
|
if err := r.db.Model(&models.Transaction{}).
|
|
Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID).
|
|
Count(&count).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to count transactions by account: %w", err)
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// CountByCategoryID returns the count of transactions for a category
|
|
func (r *TransactionRepository) CountByCategoryID(userID uint, categoryID uint) (int64, error) {
|
|
var count int64
|
|
if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND category_id = ?", userID, categoryID).Count(&count).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to count transactions by category: %w", err)
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
// GetSumByAccountID calculates the sum of transactions for an account by type
|
|
func (r *TransactionRepository) GetSumByAccountID(userID uint, accountID uint, transactionType models.TransactionType) (float64, error) {
|
|
var result struct {
|
|
Total float64
|
|
}
|
|
if err := r.db.Model(&models.Transaction{}).
|
|
Select("COALESCE(SUM(amount), 0) as total").
|
|
Where("user_id = ? AND account_id = ? AND type = ?", userID, accountID, transactionType).
|
|
Scan(&result).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to get sum by account: %w", err)
|
|
}
|
|
return result.Total, nil
|
|
}
|
|
|
|
// GetSumByCategoryID calculates the sum of transactions for a category
|
|
func (r *TransactionRepository) GetSumByCategoryID(userID uint, categoryID uint, startDate, endDate *time.Time) (float64, error) {
|
|
query := r.db.Model(&models.Transaction{}).
|
|
Select("COALESCE(SUM(amount), 0) as total").
|
|
Where("user_id = ? AND category_id = ?", userID, categoryID)
|
|
|
|
if startDate != nil {
|
|
query = query.Where("transaction_date >= ?", startDate)
|
|
}
|
|
if endDate != nil {
|
|
query = query.Where("transaction_date <= ?", endDate)
|
|
}
|
|
|
|
var result struct {
|
|
Total float64
|
|
}
|
|
if err := query.Scan(&result).Error; err != nil {
|
|
return 0, fmt.Errorf("failed to get sum by category: %w", err)
|
|
}
|
|
return result.Total, nil
|
|
}
|
|
|
|
// GetRecentTransactions retrieves the most recent transactions
|
|
func (r *TransactionRepository) GetRecentTransactions(userID uint, limit int) ([]models.Transaction, error) {
|
|
var transactions []models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).
|
|
Order("transaction_date DESC, created_at DESC").
|
|
Limit(limit).
|
|
Preload("Category").
|
|
Preload("Account").
|
|
Preload("Tags").
|
|
Find(&transactions).Error; err != nil {
|
|
return nil, fmt.Errorf("failed to get recent transactions: %w", err)
|
|
}
|
|
return transactions, nil
|
|
}
|
|
|
|
// GetRelatedTransactions retrieves all related transactions for a given transaction ID
|
|
// For an expense transaction: returns its refund income and/or reimbursement income if they exist
|
|
// For a refund/reimbursement income: returns the original expense transaction
|
|
// Feature: accounting-feature-upgrade
|
|
// Validates: Requirements 8.21, 8.22
|
|
func (r *TransactionRepository) GetRelatedTransactions(userID uint, id uint) ([]models.Transaction, error) {
|
|
var relatedTransactions []models.Transaction
|
|
|
|
// First, get the transaction itself to determine its type
|
|
var transaction models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrTransactionNotFound
|
|
}
|
|
return nil, fmt.Errorf("failed to get transaction: %w", err)
|
|
}
|
|
|
|
// Case 1: If this is an expense transaction, find its refund and reimbursement income records
|
|
if transaction.Type == models.TransactionTypeExpense {
|
|
// Find refund income if exists
|
|
if transaction.RefundIncomeID != nil {
|
|
var refundIncome models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).
|
|
Preload("Category").
|
|
Preload("Account").
|
|
First(&refundIncome, *transaction.RefundIncomeID).Error; err == nil {
|
|
relatedTransactions = append(relatedTransactions, refundIncome)
|
|
}
|
|
}
|
|
|
|
// Find reimbursement income if exists
|
|
if transaction.ReimbursementIncomeID != nil {
|
|
var reimbursementIncome models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).
|
|
Preload("Category").
|
|
Preload("Account").
|
|
First(&reimbursementIncome, *transaction.ReimbursementIncomeID).Error; err == nil {
|
|
relatedTransactions = append(relatedTransactions, reimbursementIncome)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Case 2: If this is a refund or reimbursement income, find the original expense transaction
|
|
if transaction.OriginalTransactionID != nil {
|
|
var originalTransaction models.Transaction
|
|
if err := r.db.Where("user_id = ?", userID).
|
|
Preload("Category").
|
|
Preload("Account").
|
|
First(&originalTransaction, *transaction.OriginalTransactionID).Error; err == nil {
|
|
relatedTransactions = append(relatedTransactions, originalTransaction)
|
|
}
|
|
}
|
|
|
|
return relatedTransactions, nil
|
|
}
|