init
This commit is contained in:
528
internal/repository/transaction_repository.go
Normal file
528
internal/repository/transaction_repository.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"accounting-app/internal/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Transaction repository errors
|
||||
var (
|
||||
ErrTransactionNotFound = errors.New("transaction not found")
|
||||
)
|
||||
|
||||
// TransactionFilter contains filter options for listing transactions
|
||||
type TransactionFilter struct {
|
||||
// Date range filters
|
||||
StartDate *time.Time
|
||||
EndDate *time.Time
|
||||
|
||||
// Entity filters
|
||||
CategoryID *uint
|
||||
AccountID *uint
|
||||
TagIDs []uint
|
||||
Type *models.TransactionType
|
||||
Currency *models.Currency
|
||||
RecurringID *uint
|
||||
UserID *uint
|
||||
|
||||
// Search
|
||||
NoteSearch string
|
||||
}
|
||||
|
||||
// TransactionSort defines sorting options
|
||||
type TransactionSort struct {
|
||||
Field string // "transaction_date", "amount", "created_at"
|
||||
Ascending bool
|
||||
}
|
||||
|
||||
// TransactionListOptions contains options for listing transactions
|
||||
type TransactionListOptions struct {
|
||||
Filter TransactionFilter
|
||||
Sort TransactionSort
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// TransactionListResult contains the result of a paginated transaction list query
|
||||
type TransactionListResult struct {
|
||||
Transactions []models.Transaction
|
||||
Total int64
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// TransactionRepository handles database operations for transactions
|
||||
type TransactionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTransactionRepository creates a new TransactionRepository instance
|
||||
func NewTransactionRepository(db *gorm.DB) *TransactionRepository {
|
||||
return &TransactionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new transaction in the database
|
||||
func (r *TransactionRepository) Create(transaction *models.Transaction) error {
|
||||
if err := r.db.Create(transaction).Error; err != nil {
|
||||
return fmt.Errorf("failed to create transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateWithTags creates a new transaction with associated tags
|
||||
func (r *TransactionRepository) CreateWithTags(transaction *models.Transaction, tagIDs []uint) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Create the transaction
|
||||
if err := tx.Create(transaction).Error; err != nil {
|
||||
return fmt.Errorf("failed to create transaction: %w", err)
|
||||
}
|
||||
|
||||
// Add tags
|
||||
for _, tagID := range tagIDs {
|
||||
transactionTag := models.TransactionTag{
|
||||
TransactionID: transaction.ID,
|
||||
TagID: tagID,
|
||||
}
|
||||
if err := tx.Create(&transactionTag).Error; err != nil {
|
||||
return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID retrieves a transaction by its ID
|
||||
func (r *TransactionRepository) GetByID(userID uint, id uint) (*models.Transaction, error) {
|
||||
var transaction models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTransactionNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get transaction: %w", err)
|
||||
}
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
// GetByIDWithRelations retrieves a transaction by its ID with all relations preloaded
|
||||
func (r *TransactionRepository) GetByIDWithRelations(userID uint, id uint) (*models.Transaction, error) {
|
||||
var transaction models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("ToAccount").
|
||||
Preload("Tags").
|
||||
Preload("Recurring").
|
||||
First(&transaction, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTransactionNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get transaction with relations: %w", err)
|
||||
}
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
// Update updates an existing transaction in the database
|
||||
func (r *TransactionRepository) Update(transaction *models.Transaction) error {
|
||||
// First check if the transaction exists
|
||||
var existing models.Transaction
|
||||
if err := r.db.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrTransactionNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to check transaction existence: %w", err)
|
||||
}
|
||||
|
||||
// Update the transaction
|
||||
if err := r.db.Save(transaction).Error; err != nil {
|
||||
return fmt.Errorf("failed to update transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateWithTags updates a transaction and its associated tags
|
||||
func (r *TransactionRepository) UpdateWithTags(transaction *models.Transaction, tagIDs []uint) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// First check if the transaction exists
|
||||
var existing models.Transaction
|
||||
if err := tx.Where("user_id = ?", transaction.UserID).First(&existing, transaction.ID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrTransactionNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to check transaction existence: %w", err)
|
||||
}
|
||||
|
||||
// Update the transaction
|
||||
if err := tx.Save(transaction).Error; err != nil {
|
||||
return fmt.Errorf("failed to update transaction: %w", err)
|
||||
}
|
||||
|
||||
// Clear existing tags
|
||||
if err := tx.Where("transaction_id = ?", transaction.ID).Delete(&models.TransactionTag{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to clear existing tags: %w", err)
|
||||
}
|
||||
|
||||
// Add new tags
|
||||
for _, tagID := range tagIDs {
|
||||
transactionTag := models.TransactionTag{
|
||||
TransactionID: transaction.ID,
|
||||
TagID: tagID,
|
||||
}
|
||||
if err := tx.Create(&transactionTag).Error; err != nil {
|
||||
return fmt.Errorf("failed to add tag %d to transaction: %w", tagID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Delete deletes a transaction by its ID (soft delete)
|
||||
func (r *TransactionRepository) Delete(userID uint, id uint) error {
|
||||
// First check if the transaction exists
|
||||
var transaction models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrTransactionNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to check transaction existence: %w", err)
|
||||
}
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete associated tags
|
||||
if err := tx.Where("transaction_id = ?", id).Delete(&models.TransactionTag{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete transaction tags: %w", err)
|
||||
}
|
||||
|
||||
// Delete the transaction (soft delete due to gorm.DeletedAt field)
|
||||
if err := tx.Delete(&transaction).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// List retrieves transactions with pagination, filtering, and sorting
|
||||
func (r *TransactionRepository) List(userID uint, options TransactionListOptions) (*TransactionListResult, error) {
|
||||
query := r.db.Model(&models.Transaction{}).Where("user_id = ?", userID)
|
||||
|
||||
// Apply filters
|
||||
query = r.applyFilters(query, options.Filter)
|
||||
|
||||
// Count total before pagination
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to count transactions: %w", err)
|
||||
}
|
||||
|
||||
// Apply sorting (default: transaction_date DESC)
|
||||
query = r.applySorting(query, options.Sort)
|
||||
|
||||
// Apply pagination
|
||||
if options.Limit > 0 {
|
||||
query = query.Limit(options.Limit)
|
||||
}
|
||||
if options.Offset > 0 {
|
||||
query = query.Offset(options.Offset)
|
||||
}
|
||||
|
||||
// Preload relations
|
||||
query = query.Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("ToAccount").
|
||||
Preload("Tags")
|
||||
|
||||
// Execute query
|
||||
var transactions []models.Transaction
|
||||
if err := query.Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to list transactions: %w", err)
|
||||
}
|
||||
|
||||
return &TransactionListResult{
|
||||
Transactions: transactions,
|
||||
Total: total,
|
||||
Offset: options.Offset,
|
||||
Limit: options.Limit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// applyFilters applies filter conditions to the query
|
||||
func (r *TransactionRepository) applyFilters(query *gorm.DB, filter TransactionFilter) *gorm.DB {
|
||||
// Date range filters
|
||||
if filter.StartDate != nil {
|
||||
query = query.Where("transaction_date >= ?", filter.StartDate)
|
||||
}
|
||||
if filter.EndDate != nil {
|
||||
query = query.Where("transaction_date <= ?", filter.EndDate)
|
||||
}
|
||||
|
||||
// Entity filters
|
||||
if filter.CategoryID != nil {
|
||||
query = query.Where("category_id = ?", *filter.CategoryID)
|
||||
}
|
||||
if filter.AccountID != nil {
|
||||
query = query.Where("account_id = ? OR to_account_id = ?", *filter.AccountID, *filter.AccountID)
|
||||
}
|
||||
if filter.Type != nil {
|
||||
query = query.Where("type = ?", *filter.Type)
|
||||
}
|
||||
if filter.Currency != nil {
|
||||
query = query.Where("currency = ?", *filter.Currency)
|
||||
}
|
||||
if filter.RecurringID != nil {
|
||||
query = query.Where("recurring_id = ?", *filter.RecurringID)
|
||||
}
|
||||
// UserID provided in argument takes precedence, but if filter has it, we can redundant check or ignore.
|
||||
// The caller `List` already applied `Where("user_id = ?", userID)`.
|
||||
|
||||
// Tag filter - requires subquery
|
||||
if len(filter.TagIDs) > 0 {
|
||||
query = query.Where("id IN (?)",
|
||||
r.db.Model(&models.TransactionTag{}).
|
||||
Select("transaction_id").
|
||||
Where("tag_id IN ?", filter.TagIDs))
|
||||
}
|
||||
|
||||
// Note search
|
||||
if filter.NoteSearch != "" {
|
||||
query = query.Where("note LIKE ?", "%"+filter.NoteSearch+"%")
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// applySorting applies sorting to the query
|
||||
func (r *TransactionRepository) applySorting(query *gorm.DB, sort TransactionSort) *gorm.DB {
|
||||
// Default sorting: transaction_date DESC (newest first)
|
||||
if sort.Field == "" {
|
||||
return query.Order("transaction_date DESC, created_at DESC")
|
||||
}
|
||||
|
||||
// Validate sort field
|
||||
validFields := map[string]bool{
|
||||
"transaction_date": true,
|
||||
"amount": true,
|
||||
"created_at": true,
|
||||
}
|
||||
|
||||
if !validFields[sort.Field] {
|
||||
return query.Order("transaction_date DESC, created_at DESC")
|
||||
}
|
||||
|
||||
direction := "DESC"
|
||||
if sort.Ascending {
|
||||
direction = "ASC"
|
||||
}
|
||||
|
||||
return query.Order(fmt.Sprintf("%s %s", sort.Field, direction))
|
||||
}
|
||||
|
||||
// GetByAccountID retrieves all transactions for a specific account
|
||||
func (r *TransactionRepository) GetByAccountID(userID uint, accountID uint) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID).
|
||||
Order("transaction_date DESC").
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get transactions by account: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetByCategoryID retrieves all transactions for a specific category
|
||||
func (r *TransactionRepository) GetByCategoryID(userID uint, categoryID uint) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Where("user_id = ? AND category_id = ?", userID, categoryID).
|
||||
Order("transaction_date DESC").
|
||||
Preload("Account").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get transactions by category: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetByDateRange retrieves all transactions within a date range
|
||||
func (r *TransactionRepository) GetByDateRange(userID uint, startDate, endDate time.Time) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Where("user_id = ? AND transaction_date >= ? AND transaction_date <= ?", userID, startDate, endDate).
|
||||
Order("transaction_date DESC").
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get transactions by date range: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetByTagID retrieves all transactions with a specific tag
|
||||
func (r *TransactionRepository) GetByTagID(userID uint, tagID uint) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Joins("JOIN transaction_tags ON transaction_tags.transaction_id = transactions.id").
|
||||
Where("transactions.user_id = ? AND transaction_tags.tag_id = ?", userID, tagID).
|
||||
Order("transaction_date DESC").
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get transactions by tag: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetByRecurringID retrieves all transactions generated from a recurring transaction
|
||||
func (r *TransactionRepository) GetByRecurringID(userID uint, recurringID uint) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Where("user_id = ? AND recurring_id = ?", userID, recurringID).
|
||||
Order("transaction_date DESC").
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get transactions by recurring ID: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// ExistsByID checks if a transaction with the given ID exists
|
||||
func (r *TransactionRepository) ExistsByID(userID uint, id uint) (bool, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND id = ?", userID, id).Count(&count).Error; err != nil {
|
||||
return false, fmt.Errorf("failed to check transaction existence: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// CountByAccountID returns the count of transactions for an account
|
||||
func (r *TransactionRepository) CountByAccountID(userID uint, accountID uint) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.Transaction{}).
|
||||
Where("user_id = ? AND (account_id = ? OR to_account_id = ?)", userID, accountID, accountID).
|
||||
Count(&count).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to count transactions by account: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByCategoryID returns the count of transactions for a category
|
||||
func (r *TransactionRepository) CountByCategoryID(userID uint, categoryID uint) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.Transaction{}).Where("user_id = ? AND category_id = ?", userID, categoryID).Count(&count).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to count transactions by category: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetSumByAccountID calculates the sum of transactions for an account by type
|
||||
func (r *TransactionRepository) GetSumByAccountID(userID uint, accountID uint, transactionType models.TransactionType) (float64, error) {
|
||||
var result struct {
|
||||
Total float64
|
||||
}
|
||||
if err := r.db.Model(&models.Transaction{}).
|
||||
Select("COALESCE(SUM(amount), 0) as total").
|
||||
Where("user_id = ? AND account_id = ? AND type = ?", userID, accountID, transactionType).
|
||||
Scan(&result).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to get sum by account: %w", err)
|
||||
}
|
||||
return result.Total, nil
|
||||
}
|
||||
|
||||
// GetSumByCategoryID calculates the sum of transactions for a category
|
||||
func (r *TransactionRepository) GetSumByCategoryID(userID uint, categoryID uint, startDate, endDate *time.Time) (float64, error) {
|
||||
query := r.db.Model(&models.Transaction{}).
|
||||
Select("COALESCE(SUM(amount), 0) as total").
|
||||
Where("user_id = ? AND category_id = ?", userID, categoryID)
|
||||
|
||||
if startDate != nil {
|
||||
query = query.Where("transaction_date >= ?", startDate)
|
||||
}
|
||||
if endDate != nil {
|
||||
query = query.Where("transaction_date <= ?", endDate)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Total float64
|
||||
}
|
||||
if err := query.Scan(&result).Error; err != nil {
|
||||
return 0, fmt.Errorf("failed to get sum by category: %w", err)
|
||||
}
|
||||
return result.Total, nil
|
||||
}
|
||||
|
||||
// GetRecentTransactions retrieves the most recent transactions
|
||||
func (r *TransactionRepository) GetRecentTransactions(userID uint, limit int) ([]models.Transaction, error) {
|
||||
var transactions []models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).
|
||||
Order("transaction_date DESC, created_at DESC").
|
||||
Limit(limit).
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
Preload("Tags").
|
||||
Find(&transactions).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get recent transactions: %w", err)
|
||||
}
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetRelatedTransactions retrieves all related transactions for a given transaction ID
|
||||
// For an expense transaction: returns its refund income and/or reimbursement income if they exist
|
||||
// For a refund/reimbursement income: returns the original expense transaction
|
||||
// Feature: accounting-feature-upgrade
|
||||
// Validates: Requirements 8.21, 8.22
|
||||
func (r *TransactionRepository) GetRelatedTransactions(userID uint, id uint) ([]models.Transaction, error) {
|
||||
var relatedTransactions []models.Transaction
|
||||
|
||||
// First, get the transaction itself to determine its type
|
||||
var transaction models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).First(&transaction, id).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTransactionNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get transaction: %w", err)
|
||||
}
|
||||
|
||||
// Case 1: If this is an expense transaction, find its refund and reimbursement income records
|
||||
if transaction.Type == models.TransactionTypeExpense {
|
||||
// Find refund income if exists
|
||||
if transaction.RefundIncomeID != nil {
|
||||
var refundIncome models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
First(&refundIncome, *transaction.RefundIncomeID).Error; err == nil {
|
||||
relatedTransactions = append(relatedTransactions, refundIncome)
|
||||
}
|
||||
}
|
||||
|
||||
// Find reimbursement income if exists
|
||||
if transaction.ReimbursementIncomeID != nil {
|
||||
var reimbursementIncome models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
First(&reimbursementIncome, *transaction.ReimbursementIncomeID).Error; err == nil {
|
||||
relatedTransactions = append(relatedTransactions, reimbursementIncome)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: If this is a refund or reimbursement income, find the original expense transaction
|
||||
if transaction.OriginalTransactionID != nil {
|
||||
var originalTransaction models.Transaction
|
||||
if err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Category").
|
||||
Preload("Account").
|
||||
First(&originalTransaction, *transaction.OriginalTransactionID).Error; err == nil {
|
||||
relatedTransactions = append(relatedTransactions, originalTransaction)
|
||||
}
|
||||
}
|
||||
|
||||
return relatedTransactions, nil
|
||||
}
|
||||
Reference in New Issue
Block a user