package repository import ( "errors" "fmt" "time" "accounting-app/internal/models" "gorm.io/gorm" ) // Transaction repository errors var ( ErrTransactionNotFound = errors.New("transaction not found") ) // TransactionFilter contains filter options for listing transactions type TransactionFilter struct { // Date range filters StartDate *time.Time EndDate *time.Time // Entity filters CategoryID *uint AccountID *uint TagIDs []uint Type *models.TransactionType Currency *models.Currency RecurringID *uint UserID *uint // Search NoteSearch string } // TransactionSort defines sorting options type TransactionSort struct { Field string // "transaction_date", "amount", "created_at" Ascending bool } // TransactionListOptions contains options for listing transactions type TransactionListOptions struct { Filter TransactionFilter Sort TransactionSort Offset int Limit int } // TransactionListResult contains the result of a paginated transaction list query type TransactionListResult struct { Transactions []models.Transaction Total int64 Offset int Limit int } // TransactionRepository handles database operations for transactions type TransactionRepository struct { db *gorm.DB } // NewTransactionRepository creates a new TransactionRepository instance func NewTransactionRepository(db *gorm.DB) *TransactionRepository { return &TransactionRepository{db: db} } // Create creates a new transaction in the database func (r *TransactionRepository) Create(transaction *models.Transaction) error { if err := r.db.Create(transaction).Error; err != nil { return fmt.Errorf("failed to create transaction: %w", err) } return nil } // CreateWithTags creates a new transaction with associated tags func (r *TransactionRepository) CreateWithTags(transaction *models.Transaction, tagIDs []uint) error { return r.db.Transaction(func(tx *gorm.DB) error { // Create the transaction if err := tx.Create(transaction).Error; err != nil { return fmt.Errorf("failed to create transaction: %w", err) } // Add tags for _, tagID := range tagIDs { transactionTag := models.TransactionTag{ TransactionID: transaction.ID, TagID: tagID, } if err := tx.Create(&transactionTag).Error; err != nil { return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err) } } return nil }) } // GetByID retrieves a transaction by its ID func (r *TransactionRepository) GetByID(userID uint, id uint) (*models.Transaction, error) { var transaction models.Transaction if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrTransactionNotFound } return nil, fmt.Errorf("failed to get transaction: %w", err) } return &transaction, nil } // GetByIDWithRelations retrieves a transaction by its ID with all relations preloaded func (r *TransactionRepository) GetByIDWithRelations(userID uint, id uint) (*models.Transaction, error) { var transaction models.Transaction if err := r.db.Where("user_id = ?", userID). Preload("Category"). Preload("Account"). Preload("ToAccount"). Preload("Tags"). Preload("Recurring"). First(&transaction, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrTransactionNotFound } return nil, fmt.Errorf("failed to get transaction with relations: %w", err) } return &transaction, nil } // Update updates an existing transaction in the database func (r *TransactionRepository) Update(transaction *models.Transaction) error { // First check if the transaction exists var existing models.Transaction if err := r.db.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrTransactionNotFound } return fmt.Errorf("failed to check transaction existence: %w", err) } // Update the transaction if err := r.db.Save(transaction).Error; err != nil { return fmt.Errorf("failed to update transaction: %w", err) } return nil } // UpdateWithTags updates a transaction and its associated tags func (r *TransactionRepository) UpdateWithTags(transaction *models.Transaction, tagIDs []uint) error { return r.db.Transaction(func(tx *gorm.DB) error { // First check if the transaction exists var existing models.Transaction if err := tx.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrTransactionNotFound } return fmt.Errorf("failed to check transaction existence: %w", err) } // Update the transaction if err := tx.Save(transaction).Error; err != nil { return fmt.Errorf("failed to update transaction: %w", err) } // Clear existing tags if err := tx.Where("transaction_id = ?", transaction.ID).Delete(&models.TransactionTag{}).Error; err != nil { return fmt.Errorf("failed to clear existing tags: %w", err) } // Add new tags for _, tagID := range tagIDs { transactionTag := models.TransactionTag{ TransactionID: transaction.ID, TagID: tagID, } if err := tx.Create(&transactionTag).Error; err != nil { return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err) } } return nil }) } // Delete deletes a transaction by its ID (soft delete) func (r *TransactionRepository) Delete(userID uint, id uint) error { // First check if the transaction exists var transaction models.Transaction if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrTransactionNotFound } return fmt.Errorf("failed to check transaction existence: %w", err) } return r.db.Transaction(func(tx *gorm.DB) error { // Delete associated tags if err := tx.Where("transaction_id = ?", id).Delete(&models.TransactionTag{}).Error; err != nil { return fmt.Errorf("failed to delete transaction tags: %w", err) } // Delete the transaction (soft delete due to gorm.DeletedAt field) if err := tx.Delete(&transaction).Error; err != nil { return fmt.Errorf("failed to delete transaction: %w", err) } return nil }) } // List retrieves transactions with pagination, filtering, and sorting func (r *TransactionRepository) List(userID uint, options TransactionListOptions) (*TransactionListResult, error) { query := r.db.Model(&models.Transaction{}).Where("user_id = ?", userID) // Apply filters query = r.applyFilters(query, options.Filter) // Count total before pagination var total int64 if err := query.Count(&total).Error; err != nil { return nil, fmt.Errorf("failed to count transactions: %w", err) } // Apply sorting (default: transaction_date DESC) query = r.applySorting(query, options.Sort) // Apply pagination if options.Limit > 0 { query = query.Limit(options.Limit) } if options.Offset > 0 { query = query.Offset(options.Offset) } // Preload relations query = query.Preload("Category"). Preload("Account"). Preload("ToAccount"). Preload("Tags") // Execute query var transactions []models.Transaction if err := query.Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to list transactions: %w", err) } return &TransactionListResult{ Transactions: transactions, Total: total, Offset: options.Offset, Limit: options.Limit, }, nil } // applyFilters applies filter conditions to the query func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionFilter) *gorm.DB { // Date range filters if filter.StartDate != nil { query = query.Where("transaction_date >= ?", filter.StartDate) } if filter.EndDate != nil { query = query.Where("transaction_date <= ?", filter.EndDate) } // Entity filters if filter.CategoryID != nil { query = query.Where("category_id = ?", *filter.CategoryID) } if filter.AccountID != nil { query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID) } if filter.Type != nil { query = query.Where("type = ?", *filter.Type) } if filter.Currency != nil { query = query.Where("currency = ?", *filter.Currency) } if filter.RecurringID != nil { query = query.Where("recurring_id = ?", *filter.RecurringID) } // UserID provided in argument takes precedence, but if filter has it, we can redundant check or ignore. // The caller `List` already applied `Where("user_id = ?", userID)`. // Tag filter - requires subquery if len(filter.TagIDs) > 0 { query = query.Where("id IN (?)", r.db.Model(&models.TransactionTag{}). Select("transaction_id"). Where("tag_id IN ?", filter.TagIDs)) } // Note search if filter.NoteSearch != "" { query = query.Where("note LIKE ?", "%"+filter.NoteSearch+"%") } return query } // applySorting applies sorting to the query func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB { // Default sorting: transaction_date DESC (newest first) if sort.Field == "" { return query.Order("transaction_date DESC, created_at DESC") } // Validate sort field validFields := map[string]bool{ "transaction_date": true, "amount": true, "created_at": true, } if !validFields[sort.Field] { return query.Order("transaction_date DESC, created_at DESC") } direction := "DESC" if sort.Ascending { direction = "ASC" } return query.Order(fmt.Sprintf("%s %s", sort.Field, direction)) } // GetByAccountID retrieves all transactions for a specific account func (r *TransactionRepository) GetByAccountID(userID uint, accountID uint) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID). Order("transaction_date DESC"). Preload("Category"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get transactions by account: %w", err) } return transactions, nil } // GetByCategoryID retrieves all transactions for a specific category func (r *TransactionRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID). Order("transaction_date DESC"). Preload("Account"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get transactions by category: %w", err) } return transactions, nil } // GetByDateRange retrieves all transactions within a date range func (r *TransactionRepository) GetByDateRange(userID uint, startDate, endDate time.Time) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate). Order("transaction_date DESC"). Preload("Category"). Preload("Account"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get transactions by date range: %w", err) } return transactions, nil } // GetByTagID retrieves all transactions with a specific tag func (r *TransactionRepository) GetByTagID(userID uint, tagID uint) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Joins("JOIN transaction_tags ON transaction_tags.transaction_id = transactions.id"). Where("transactions.user_id = ? AND transaction_tags.tag_id = ?", userID, tagID). Order("transaction_date DESC"). Preload("Category"). Preload("Account"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get transactions by tag: %w", err) } return transactions, nil } // GetByRecurringID retrieves all transactions generated from a recurring transaction func (r *TransactionRepository) GetByRecurringID(userID uint, recurringID uint) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Where("user_id = ? AND recurring_id = ?", userID, recurringID). Order("transaction_date DESC"). Preload("Category"). Preload("Account"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get transactions by recurring ID: %w", err) } return transactions, nil } // ExistsByID checks if a transaction with the given ID exists func (r *TransactionRepository) ExistsByID(userID uint, id uint) (bool, error) { var count int64 if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil { return false, fmt.Errorf("failed to check transaction existence: %w", err) } return count > 0, nil } // CountByAccountID returns the count of transactions for an account func (r *TransactionRepository) CountByAccountID(userID uint, accountID uint) (int64, error) { var count int64 if err := r.db.Model(&models.Transaction{}). Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID). Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to count transactions by account: %w", err) } return count, nil } // CountByCategoryID returns the count of transactions for a category func (r *TransactionRepository) CountByCategoryID(userID uint, categoryID uint) (int64, error) { var count int64 if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND category_id = ?", userID, categoryID).Count(&count).Error; err != nil { return 0, fmt.Errorf("failed to count transactions by category: %w", err) } return count, nil } // GetSumByAccountID calculates the sum of transactions for an account by type func (r *TransactionRepository) GetSumByAccountID(userID uint, accountID uint, transactionType models.TransactionType) (float64, error) { var result struct { Total float64 } if err := r.db.Model(&models.Transaction{}). Select("COALESCE(SUM(amount), 0) as total"). Where("user_id = ? AND account_id = ? AND type = ?", userID, accountID, transactionType). Scan(&result).Error; err != nil { return 0, fmt.Errorf("failed to get sum by account: %w", err) } return result.Total, nil } // GetSumByCategoryID calculates the sum of transactions for a category func (r *TransactionRepository) GetSumByCategoryID(userID uint, categoryID uint, startDate, endDate *time.Time) (float64, error) { query := r.db.Model(&models.Transaction{}). Select("COALESCE(SUM(amount), 0) as total"). Where("user_id = ? AND category_id = ?", userID, categoryID) if startDate != nil { query = query.Where("transaction_date >= ?", startDate) } if endDate != nil { query = query.Where("transaction_date <= ?", endDate) } var result struct { Total float64 } if err := query.Scan(&result).Error; err != nil { return 0, fmt.Errorf("failed to get sum by category: %w", err) } return result.Total, nil } // GetRecentTransactions retrieves the most recent transactions func (r *TransactionRepository) GetRecentTransactions(userID uint, limit int) ([]models.Transaction, error) { var transactions []models.Transaction if err := r.db.Where("user_id = ?", userID). Order("transaction_date DESC, created_at DESC"). Limit(limit). Preload("Category"). Preload("Account"). Preload("Tags"). Find(&transactions).Error; err != nil { return nil, fmt.Errorf("failed to get recent transactions: %w", err) } return transactions, nil } // GetRelatedTransactions retrieves all related transactions for a given transaction ID // For an expense transaction: returns its refund income and/or reimbursement income if they exist // For a refund/reimbursement income: returns the original expense transaction // Feature: accounting-feature-upgrade // Validates: Requirements 8.21, 8.22 func (r *TransactionRepository) GetRelatedTransactions(userID uint, id uint) ([]models.Transaction, error) { var relatedTransactions []models.Transaction // First, get the transaction itself to determine its type var transaction models.Transaction if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrTransactionNotFound } return nil, fmt.Errorf("failed to get transaction: %w", err) } // Case 1: If this is an expense transaction, find its refund and reimbursement income records if transaction.Type == models.TransactionTypeExpense { // Find refund income if exists if transaction.RefundIncomeID != nil { var refundIncome models.Transaction if err := r.db.Where("user_id = ?", userID). Preload("Category"). Preload("Account"). First(&refundIncome, *transaction.RefundIncomeID).Error; err == nil { relatedTransactions = append(relatedTransactions, refundIncome) } } // Find reimbursement income if exists if transaction.ReimbursementIncomeID != nil { var reimbursementIncome models.Transaction if err := r.db.Where("user_id = ?", userID). Preload("Category"). Preload("Account"). First(&reimbursementIncome, *transaction.ReimbursementIncomeID).Error; err == nil { relatedTransactions = append(relatedTransactions, reimbursementIncome) } } } // Case 2: If this is a refund or reimbursement income, find the original expense transaction if transaction.OriginalTransactionID != nil { var originalTransaction models.Transaction if err := r.db.Where("user_id = ?", userID). Preload("Category"). Preload("Account"). First(&originalTransaction, *transaction.OriginalTransactionID).Error; err == nil { relatedTransactions = append(relatedTransactions, originalTransaction) } } return relatedTransactions, nil }