Files
Novault-backend/internal/repository/transaction_repository.go
2026-01-25 21:59:00 +08:00

529 lines
17 KiB
Go

package repository
import (
"errors"
"fmt"
"time"
"accounting-app/internal/models"
"gorm.io/gorm"
)
// ErrTransactionNotFound is the sentinel returned by this repository when no
// transaction with the requested ID is visible to the requesting user (lookups
// are scoped by user_id). Compare against it with errors.Is.
var ErrTransactionNotFound = errors.New("transaction not found")
// TransactionFilter contains filter options for listing transactions.
// A nil pointer field, an empty TagIDs slice, or an empty NoteSearch means
// "no constraint" for that field; all set fields are combined with AND
// (see applyFilters).
type TransactionFilter struct {
// Date range filters; both bounds are inclusive (>= StartDate, <= EndDate).
StartDate *time.Time
EndDate *time.Time
// Entity filters
CategoryID *uint
// AccountID matches rows whose account_id OR to_account_id equals it.
AccountID *uint
// TagIDs matches transactions linked to ANY of the given tags.
TagIDs []uint
Type *models.TransactionType
Currency *models.Currency
RecurringID *uint
// UserID is NOT applied by applyFilters; List scopes results by its own
// userID argument instead. NOTE(review): setting this field has no effect.
UserID *uint
// NoteSearch is a substring match against the note column (SQL LIKE).
NoteSearch string
}
// TransactionSort defines sorting options.
// An empty or unrecognized Field falls back to the default ordering
// (newest transaction_date first); see applySorting.
type TransactionSort struct {
Field string // "transaction_date", "amount", "created_at"
Ascending bool // false sorts descending
}
// TransactionListOptions contains options for listing transactions.
// Limit <= 0 means no LIMIT clause and Offset <= 0 means no OFFSET clause
// (see List).
type TransactionListOptions struct {
Filter TransactionFilter
Sort TransactionSort
Offset int
Limit int
}
// TransactionListResult contains the result of a paginated transaction list query.
type TransactionListResult struct {
// Transactions is the requested page of matching rows.
Transactions []models.Transaction
// Total is the number of rows matching the filters, counted before
// Offset/Limit are applied.
Total int64
// Offset and Limit echo the values that were requested.
Offset int
Limit int
}
// TransactionRepository handles database operations for transactions.
// Methods that take a userID restrict their queries to rows with that user_id.
type TransactionRepository struct {
// db is the GORM handle every query is built from.
db *gorm.DB
}
// NewTransactionRepository returns a TransactionRepository that issues all of
// its queries through the given GORM handle.
func NewTransactionRepository(db *gorm.DB) *TransactionRepository {
	repo := TransactionRepository{db: db}
	return &repo
}
// Create inserts transaction into the database. The struct is passed to GORM
// by pointer, so generated values (such as the primary key) are written back
// into it on success.
func (r *TransactionRepository) Create(transaction *models.Transaction) error {
	err := r.db.Create(transaction).Error
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to create transaction: %w", err)
}
// CreateWithTags inserts transaction and links it to every tag in tagIDs
// inside a single database transaction: if any insert fails, everything is
// rolled back and nothing is persisted.
//
// Duplicate IDs in tagIDs are linked only once. Inserting the same
// (transaction, tag) pair twice would either violate a uniqueness constraint
// on the join table, aborting the whole create, or leave duplicate join rows
// that make tag-joined queries return the transaction twice.
// NOTE(review): tag ownership is not verified here — confirm callers check it.
func (r *TransactionRepository) CreateWithTags(transaction *models.Transaction, tagIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Create the transaction; this populates transaction.ID for the links below.
		if err := tx.Create(transaction).Error; err != nil {
			return fmt.Errorf("failed to create transaction: %w", err)
		}

		// Add tags, skipping repeated IDs.
		seen := make(map[uint]struct{}, len(tagIDs))
		for _, tagID := range tagIDs {
			if _, dup := seen[tagID]; dup {
				continue
			}
			seen[tagID] = struct{}{}

			transactionTag := models.TransactionTag{
				TransactionID: transaction.ID,
				TagID:         tagID,
			}
			if err := tx.Create(&transactionTag).Error; err != nil {
				return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
			}
		}
		return nil
	})
}
// GetByID returns the transaction with primary key id that belongs to userID.
// It returns ErrTransactionNotFound when no such row is found, and a wrapped
// error for any other database failure.
func (r *TransactionRepository) GetByID(userID uint, id uint) (*models.Transaction, error) {
	transaction := new(models.Transaction)
	err := r.db.Where("user_id = ?", userID).First(transaction, id).Error
	switch {
	case err == nil:
		return transaction, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrTransactionNotFound
	default:
		return nil, fmt.Errorf("failed to get transaction: %w", err)
	}
}
// GetByIDWithRelations is like GetByID but also preloads the Category,
// Account, ToAccount, Tags and Recurring associations.
// It returns ErrTransactionNotFound when no matching row is found.
func (r *TransactionRepository) GetByIDWithRelations(userID uint, id uint) (*models.Transaction, error) {
	query := r.db.Where("user_id = ?", userID)
	for _, relation := range []string{"Category", "Account", "ToAccount", "Tags", "Recurring"} {
		query = query.Preload(relation)
	}

	transaction := new(models.Transaction)
	err := query.First(transaction, id).Error
	switch {
	case err == nil:
		return transaction, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, ErrTransactionNotFound
	default:
		return nil, fmt.Errorf("failed to get transaction with relations: %w", err)
	}
}
// Update overwrites an existing transaction with the values in transaction
// (GORM Save). The row identified by transaction.ID must exist and belong to
// transaction.UserID, otherwise ErrTransactionNotFound is returned and nothing
// is written.
//
// The ownership check and the save run inside one database transaction, as
// UpdateWithTags does, so they use the same connection and are committed or
// rolled back together.
func (r *TransactionRepository) Update(transaction *models.Transaction) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// First check that the transaction exists for this user.
		var existing models.Transaction
		if err := tx.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrTransactionNotFound
			}
			return fmt.Errorf("failed to check transaction existence: %w", err)
		}

		// Update the transaction.
		if err := tx.Save(transaction).Error; err != nil {
			return fmt.Errorf("failed to update transaction: %w", err)
		}
		return nil
	})
}
// UpdateWithTags overwrites an existing transaction and replaces its tag links
// with tagIDs, all inside one database transaction.
//
// The row identified by transaction.ID must exist and belong to
// transaction.UserID, otherwise ErrTransactionNotFound is returned and nothing
// is changed. An empty tagIDs removes all tag links.
//
// Duplicate IDs in tagIDs are linked only once. Inserting the same
// (transaction, tag) pair twice would either violate a uniqueness constraint
// on the join table, rolling back the whole update, or leave duplicate join
// rows.
func (r *TransactionRepository) UpdateWithTags(transaction *models.Transaction, tagIDs []uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// First check that the transaction exists for this user.
		var existing models.Transaction
		if err := tx.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrTransactionNotFound
			}
			return fmt.Errorf("failed to check transaction existence: %w", err)
		}

		// Update the transaction.
		if err := tx.Save(transaction).Error; err != nil {
			return fmt.Errorf("failed to update transaction: %w", err)
		}

		// Clear existing tag links.
		if err := tx.Where("transaction_id = ?", transaction.ID).Delete(&models.TransactionTag{}).Error; err != nil {
			return fmt.Errorf("failed to clear existing tags: %w", err)
		}

		// Add the new tag links, skipping repeated IDs.
		seen := make(map[uint]struct{}, len(tagIDs))
		for _, tagID := range tagIDs {
			if _, dup := seen[tagID]; dup {
				continue
			}
			seen[tagID] = struct{}{}

			transactionTag := models.TransactionTag{
				TransactionID: transaction.ID,
				TagID:         tagID,
			}
			if err := tx.Create(&transactionTag).Error; err != nil {
				return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
			}
		}
		return nil
	})
}
// Delete removes the transaction with the given ID owned by userID, together
// with its tag links. The transaction row itself is soft-deleted when the model
// carries a gorm.DeletedAt field.
//
// The existence check now runs inside the same database transaction as the
// deletes, as UpdateWithTags does, instead of on a separate connection
// beforehand. ErrTransactionNotFound is returned when no matching row exists.
func (r *TransactionRepository) Delete(userID uint, id uint) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// First check that the transaction exists for this user.
		var transaction models.Transaction
		if err := tx.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrTransactionNotFound
			}
			return fmt.Errorf("failed to check transaction existence: %w", err)
		}

		// Delete associated tag links.
		if err := tx.Where("transaction_id = ?", id).Delete(&models.TransactionTag{}).Error; err != nil {
			return fmt.Errorf("failed to delete transaction tags: %w", err)
		}

		// Delete the transaction (soft delete due to gorm.DeletedAt field).
		if err := tx.Delete(&transaction).Error; err != nil {
			return fmt.Errorf("failed to delete transaction: %w", err)
		}
		return nil
	})
}
// List returns one page of userID's transactions that match options.Filter,
// ordered by options.Sort, with Category, Account, ToAccount and Tags
// preloaded. The result's Total is the number of matching rows counted before
// Offset/Limit are applied. A non-positive Limit or Offset is not applied.
func (r *TransactionRepository) List(userID uint, options TransactionListOptions) (*TransactionListResult, error) {
	filtered := r.applyFilters(
		r.db.Model(&models.Transaction{}).Where("user_id = ?", userID),
		options.Filter,
	)

	// Count the full match set before paginating.
	var total int64
	if err := filtered.Count(&total).Error; err != nil {
		return nil, fmt.Errorf("failed to count transactions: %w", err)
	}

	page := r.applySorting(filtered, options.Sort)
	if options.Limit > 0 {
		page = page.Limit(options.Limit)
	}
	if options.Offset > 0 {
		page = page.Offset(options.Offset)
	}

	var transactions []models.Transaction
	err := page.
		Preload("Category").
		Preload("Account").
		Preload("ToAccount").
		Preload("Tags").
		Find(&transactions).Error
	if err != nil {
		return nil, fmt.Errorf("failed to list transactions: %w", err)
	}

	result := &TransactionListResult{
		Transactions: transactions,
		Total:        total,
		Offset:       options.Offset,
		Limit:        options.Limit,
	}
	return result, nil
}
// applyFilters narrows query by every set field of filter (nil pointers, an
// empty TagIDs slice and an empty NoteSearch are skipped) and returns the
// resulting query. All conditions are combined with AND.
//
// filter.UserID is intentionally ignored: List already scopes the query with
// its own userID argument before calling this.
func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionFilter) *gorm.DB {
	// Date range filters (inclusive on both ends).
	if filter.StartDate != nil {
		query = query.Where("transaction_date >= ?", filter.StartDate)
	}
	if filter.EndDate != nil {
		query = query.Where("transaction_date <= ?", filter.EndDate)
	}

	// Entity filters.
	if filter.CategoryID != nil {
		query = query.Where("category_id = ?", *filter.CategoryID)
	}
	if filter.AccountID != nil {
		// Explicit parentheses keep this OR from binding to the surrounding
		// AND-ed conditions, independent of how the builder joins clauses.
		query = query.Where("(account_id = ? OR to_account_id = ?)", *filter.AccountID, *filter.AccountID)
	}
	if filter.Type != nil {
		query = query.Where("type = ?", *filter.Type)
	}
	if filter.Currency != nil {
		query = query.Where("currency = ?", *filter.Currency)
	}
	if filter.RecurringID != nil {
		query = query.Where("recurring_id = ?", *filter.RecurringID)
	}

	// Tag filter: transactions linked to ANY of the given tags (subquery).
	if len(filter.TagIDs) > 0 {
		query = query.Where("id IN (?)",
			r.db.Model(&models.TransactionTag{}).
				Select("transaction_id").
				Where("tag_id IN ?", filter.TagIDs))
	}

	// Note search: literal substring match. '%' and '_' typed by the user
	// would otherwise act as LIKE wildcards, so they are escaped. '!' is the
	// escape character because, unlike '\', it is written the same way in
	// SQLite, PostgreSQL and MySQL string literals. The loop works on bytes;
	// the escaped characters are ASCII and never occur inside a multi-byte
	// UTF-8 sequence, so other text passes through unchanged.
	if filter.NoteSearch != "" {
		search := filter.NoteSearch
		pattern := make([]byte, 0, len(search)+2)
		pattern = append(pattern, '%')
		for i := 0; i < len(search); i++ {
			c := search[i]
			if c == '!' || c == '%' || c == '_' {
				pattern = append(pattern, '!')
			}
			pattern = append(pattern, c)
		}
		pattern = append(pattern, '%')
		query = query.Where("note LIKE ? ESCAPE '!'", string(pattern))
	}

	return query
}
// applySorting orders query according to sort and returns the result.
//
// Only the whitelisted columns "transaction_date", "amount" and "created_at"
// are ever interpolated into ORDER BY. An empty or unrecognized sort.Field
// falls back to the default ordering, newest first
// (transaction_date DESC, created_at DESC).
//
// Every ordering ends with the primary key as a final tiebreaker. Without it,
// rows with equal sort values (e.g. the same amount) come back in unspecified
// order, and offset-based pagination in List can repeat or skip rows between
// pages.
func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB {
	direction := "DESC"
	if sort.Ascending {
		direction = "ASC"
	}

	switch sort.Field {
	case "transaction_date", "amount", "created_at":
		return query.Order(sort.Field + " " + direction + ", id " + direction)
	default:
		return query.Order("transaction_date DESC, created_at DESC, id DESC")
	}
}
// GetByAccountID returns userID's transactions whose account_id or
// to_account_id equals accountID, newest transaction_date first, with
// Category and Tags preloaded.
func (r *TransactionRepository) GetByAccountID(userID uint, accountID uint) ([]models.Transaction, error) {
	const condition = "user_id = ? AND (account_id = ? OR to_account_id = ?)"

	var matched []models.Transaction
	err := r.db.Where(condition, userID, accountID, accountID).
		Order("transaction_date DESC").
		Preload("Category").
		Preload("Tags").
		Find(&matched).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get transactions by account: %w", err)
	}
	return matched, nil
}
// GetByCategoryID returns userID's transactions in the given category, newest
// transaction_date first, with Account and Tags preloaded.
func (r *TransactionRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.Transaction, error) {
	var matched []models.Transaction
	err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID).
		Order("transaction_date DESC").
		Preload("Account").
		Preload("Tags").
		Find(&matched).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get transactions by category: %w", err)
	}
	return matched, nil
}
// GetByDateRange returns userID's transactions whose transaction_date lies in
// the inclusive range [startDate, endDate], newest first, with Category,
// Account and Tags preloaded.
func (r *TransactionRepository) GetByDateRange(userID uint, startDate, endDate time.Time) ([]models.Transaction, error) {
	const condition = "user_id = ? AND transaction_date >= ? AND transaction_date <= ?"

	var inRange []models.Transaction
	err := r.db.Where(condition, userID, startDate, endDate).
		Order("transaction_date DESC").
		Preload("Category").
		Preload("Account").
		Preload("Tags").
		Find(&inRange).Error
	if err != nil {
		return nil, fmt.Errorf("failed to get transactions by date range: %w", err)
	}
	return inRange, nil
}
// GetByTagID returns userID's transactions linked to tagID through the
// transaction_tags join table, newest transaction_date first, with Category,
// Account and Tags preloaded.
func (r *TransactionRepository) GetByTagID(userID uint, tagID uint) ([]models.Transaction, error) {
	query := r.db.
		Joins("JOIN transaction_tags ON transaction_tags.transaction_id = transactions.id").
		Where("transactions.user_id = ? AND transaction_tags.tag_id = ?", userID, tagID).
		Order("transaction_date DESC")
	for _, relation := range []string{"Category", "Account", "Tags"} {
		query = query.Preload(relation)
	}

	var tagged []models.Transaction
	if err := query.Find(&tagged).Error; err != nil {
		return nil, fmt.Errorf("failed to get transactions by tag: %w", err)
	}
	return tagged, nil
}
// GetByRecurringID returns userID's transactions whose recurring_id equals
// recurringID, newest transaction_date first, with Category, Account and Tags
// preloaded.
func (r *TransactionRepository) GetByRecurringID(userID uint, recurringID uint) ([]models.Transaction, error) {
	query := r.db.Where("user_id = ? AND recurring_id = ?", userID, recurringID).
		Order("transaction_date DESC")
	for _, relation := range []string{"Category", "Account", "Tags"} {
		query = query.Preload(relation)
	}

	var generated []models.Transaction
	if err := query.Find(&generated).Error; err != nil {
		return nil, fmt.Errorf("failed to get transactions by recurring ID: %w", err)
	}
	return generated, nil
}
// ExistsByID reports whether a transaction with the given ID exists for userID.
func (r *TransactionRepository) ExistsByID(userID uint, id uint) (bool, error) {
	var matches int64
	err := r.db.Model(&models.Transaction{}).
		Where("user_id = ? AND id = ?", userID, id).
		Count(&matches).Error
	if err != nil {
		return false, fmt.Errorf("failed to check transaction existence: %w", err)
	}
	return matches != 0, nil
}
// CountByAccountID returns how many of userID's transactions reference
// accountID as either account_id or to_account_id.
func (r *TransactionRepository) CountByAccountID(userID uint, accountID uint) (int64, error) {
	const condition = "user_id = ? AND (account_id = ? OR to_account_id = ?)"

	var n int64
	err := r.db.Model(&models.Transaction{}).
		Where(condition, userID, accountID, accountID).
		Count(&n).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count transactions by account: %w", err)
	}
	return n, nil
}
// CountByCategoryID returns how many of userID's transactions are in the given
// category.
func (r *TransactionRepository) CountByCategoryID(userID uint, categoryID uint) (int64, error) {
	var n int64
	err := r.db.Model(&models.Transaction{}).
		Where("user_id = ? AND category_id = ?", userID, categoryID).
		Count(&n).Error
	if err != nil {
		return 0, fmt.Errorf("failed to count transactions by category: %w", err)
	}
	return n, nil
}
// GetSumByAccountID returns the sum of amount over userID's transactions of
// the given type whose account_id equals accountID (rows that only reference
// the account via to_account_id are not included). It returns 0 when nothing
// matches (COALESCE).
func (r *TransactionRepository) GetSumByAccountID(userID uint, accountID uint, transactionType models.TransactionType) (float64, error) {
	// The field name must stay Total so GORM maps the "total" column onto it.
	var sum struct {
		Total float64
	}

	query := r.db.Model(&models.Transaction{}).
		Select("COALESCE(SUM(amount), 0) as total").
		Where("user_id = ? AND account_id = ? AND type = ?", userID, accountID, transactionType)
	if err := query.Scan(&sum).Error; err != nil {
		return 0, fmt.Errorf("failed to get sum by account: %w", err)
	}
	return sum.Total, nil
}
// GetSumByCategoryID returns the sum of amount over userID's transactions in
// the given category, optionally limited to transaction_date >= startDate
// and/or <= endDate (a nil bound is not applied). It returns 0 when nothing
// matches (COALESCE).
func (r *TransactionRepository) GetSumByCategoryID(userID uint, categoryID uint, startDate, endDate *time.Time) (float64, error) {
	query := r.db.Model(&models.Transaction{}).
		Select("COALESCE(SUM(amount), 0) as total").
		Where("user_id = ? AND category_id = ?", userID, categoryID)

	bounds := []struct {
		condition string
		value     *time.Time
	}{
		{"transaction_date >= ?", startDate},
		{"transaction_date <= ?", endDate},
	}
	for _, bound := range bounds {
		if bound.value != nil {
			query = query.Where(bound.condition, bound.value)
		}
	}

	// The field name must stay Total so GORM maps the "total" column onto it.
	var sum struct {
		Total float64
	}
	if err := query.Scan(&sum).Error; err != nil {
		return 0, fmt.Errorf("failed to get sum by category: %w", err)
	}
	return sum.Total, nil
}
// GetRecentTransactions returns up to limit of userID's transactions ordered
// by transaction_date then created_at, newest first, with Category, Account
// and Tags preloaded.
// NOTE(review): limit is passed straight to GORM's Limit; how a value <= 0 is
// rendered depends on the GORM version — confirm callers pass a positive limit.
func (r *TransactionRepository) GetRecentTransactions(userID uint, limit int) ([]models.Transaction, error) {
	query := r.db.Where("user_id = ?", userID).
		Order("transaction_date DESC, created_at DESC").
		Limit(limit)
	for _, relation := range []string{"Category", "Account", "Tags"} {
		query = query.Preload(relation)
	}

	var recent []models.Transaction
	if err := query.Find(&recent).Error; err != nil {
		return nil, fmt.Errorf("failed to get recent transactions: %w", err)
	}
	return recent, nil
}
// GetRelatedTransactions returns the transactions linked to the transaction
// with the given ID owned by userID:
//   - for an expense: its refund income (RefundIncomeID) and/or reimbursement
//     income (ReimbursementIncomeID), when set;
//   - for any transaction with OriginalTransactionID set (a refund or
//     reimbursement income): the original expense.
//
// Related records are returned in that order, with Category and Account
// preloaded. A link that points to a row not found for this user (e.g. one that
// was deleted) is skipped. Any other database error is returned instead of
// being silently dropped, so callers never mistake a failed lookup for
// "nothing related". The result is nil when nothing is related.
//
// Returns ErrTransactionNotFound if the transaction itself does not exist.
//
// Feature: accounting-feature-upgrade
// Validates: Requirements 8.21, 8.22
func (r *TransactionRepository) GetRelatedTransactions(userID uint, id uint) ([]models.Transaction, error) {
	// First, get the transaction itself to determine its type and links.
	var transaction models.Transaction
	if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrTransactionNotFound
		}
		return nil, fmt.Errorf("failed to get transaction: %w", err)
	}

	var relatedTransactions []models.Transaction

	// appendLinked loads the user's transaction with primary key linkedID and
	// appends it; not-found is skipped, other errors are returned.
	appendLinked := func(linkedID interface{}) error {
		var linked models.Transaction
		err := r.db.Where("user_id = ?", userID).
			Preload("Category").
			Preload("Account").
			First(&linked, linkedID).Error
		switch {
		case err == nil:
			relatedTransactions = append(relatedTransactions, linked)
			return nil
		case errors.Is(err, gorm.ErrRecordNotFound):
			return nil // dangling link: the related record is gone
		default:
			return fmt.Errorf("failed to get related transaction %v: %w", linkedID, err)
		}
	}

	// Case 1: an expense links to its refund and reimbursement income records.
	if transaction.Type == models.TransactionTypeExpense {
		if transaction.RefundIncomeID != nil {
			if err := appendLinked(*transaction.RefundIncomeID); err != nil {
				return nil, err
			}
		}
		if transaction.ReimbursementIncomeID != nil {
			if err := appendLinked(*transaction.ReimbursementIncomeID); err != nil {
				return nil, err
			}
		}
	}

	// Case 2: a refund or reimbursement income links back to the original expense.
	if transaction.OriginalTransactionID != nil {
		if err := appendLinked(*transaction.OriginalTransactionID); err != nil {
			return nil, err
		}
	}

	return relatedTransactions, nil
}