Files
Novault-backend/internal/service/recurring_transaction_service.go
2026-01-25 21:59:00 +08:00

548 lines
19 KiB
Go

package service
import (
	"errors"
	"fmt"
	"math"
	"time"

	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"gorm.io/gorm"
)
// RecurringTransactionService handles business logic for recurring transactions:
// validating definitions against the user's accounts and categories, and turning
// due occurrences into real transactions (with account balance updates and, for
// income, allocation-rule processing).
type RecurringTransactionService struct {
// recurringRepo persists the recurring-transaction definitions themselves.
recurringRepo *repository.RecurringTransactionRepository
// NOTE(review): transactionRepo is not referenced in the code visible here;
// generated transactions are inserted through db inside ProcessDueTransactions.
transactionRepo *repository.TransactionRepository
// accountRepo is used to check that an account resolves for the user and to read its currency.
accountRepo *repository.AccountRepository
// categoryRepo is used to check that a category resolves for the user.
categoryRepo *repository.CategoryRepository
// allocationRuleRepo looks up income allocation rules; when nil,
// ProcessDueTransactions skips allocation entirely.
allocationRuleRepo *repository.AllocationRuleRepository
// NOTE(review): recordRepo is not referenced in the code visible here; allocation
// records are written through the open database transaction instead.
recordRepo *repository.AllocationRecordRepository
// NOTE(review): piggyBankRepo is not referenced in the code visible here; piggy
// banks are read and updated through the open database transaction instead.
piggyBankRepo *repository.PiggyBankRepository
// db is used to open one database transaction per processed occurrence.
db *gorm.DB
}
// NewRecurringTransactionService wires the repositories and the database handle
// used by RecurringTransactionService and returns the ready-to-use service.
// allocationRuleRepo may be nil; income occurrences are then processed without
// applying allocation rules.
func NewRecurringTransactionService(
	recurringRepo *repository.RecurringTransactionRepository,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	allocationRuleRepo *repository.AllocationRuleRepository,
	recordRepo *repository.AllocationRecordRepository,
	piggyBankRepo *repository.PiggyBankRepository,
	db *gorm.DB,
) *RecurringTransactionService {
	svc := new(RecurringTransactionService)

	// Repositories for the definitions and the entities they reference.
	svc.recurringRepo = recurringRepo
	svc.transactionRepo = transactionRepo
	svc.accountRepo = accountRepo
	svc.categoryRepo = categoryRepo

	// Allocation-related dependencies.
	svc.allocationRuleRepo = allocationRuleRepo
	svc.recordRepo = recordRepo
	svc.piggyBankRepo = piggyBankRepo

	// Raw handle used to open per-occurrence database transactions.
	svc.db = db

	return svc
}
// CreateRecurringTransactionRequest represents the request to create a recurring transaction.
// The `binding` tags are presumably enforced by the HTTP framework's binder (gin-style) —
// confirm. Create additionally checks that the account and category resolve for UserID,
// that Currency equals the account's currency, and that EndDate (if set) is not before
// StartDate.
type CreateRecurringTransactionRequest struct {
// UserID has no binding tag; presumably the handler fills it from the authenticated user — confirm.
UserID uint `json:"user_id"`
Amount float64 `json:"amount" binding:"required,gt=0"`
Type models.TransactionType `json:"type" binding:"required,oneof=income expense"`
CategoryID uint `json:"category_id" binding:"required"`
AccountID uint `json:"account_id" binding:"required"`
// Currency must match the account's currency (checked in Create).
Currency models.Currency `json:"currency" binding:"required"`
Note string `json:"note"`
Frequency models.FrequencyType `json:"frequency" binding:"required,oneof=daily weekly monthly yearly"`
// StartDate is also used as the first scheduled occurrence.
StartDate time.Time `json:"start_date" binding:"required"`
// EndDate is optional; nil means the recurrence has no end date.
EndDate *time.Time `json:"end_date"`
}
// UpdateRecurringTransactionRequest represents a partial update of a recurring transaction.
// Nil pointer fields are left unchanged by Update.
type UpdateRecurringTransactionRequest struct {
Amount *float64 `json:"amount" binding:"omitempty,gt=0"`
Type *models.TransactionType `json:"type" binding:"omitempty,oneof=income expense"`
CategoryID *uint `json:"category_id"`
AccountID *uint `json:"account_id"`
Currency *models.Currency `json:"currency"`
Note *string `json:"note"`
Frequency *models.FrequencyType `json:"frequency" binding:"omitempty,oneof=daily weekly monthly yearly"`
StartDate *time.Time `json:"start_date"`
EndDate *time.Time `json:"end_date"`
ClearEndDate bool `json:"clear_end_date"` // when true, clears the end date; takes precedence over EndDate
IsActive *bool `json:"is_active"`
}
// Create validates req and persists a new, active recurring transaction for
// req.UserID.
//
// The account and category are looked up for req.UserID; a missing one yields
// "account not found" / "category not found", any other lookup failure is
// wrapped. req.Currency must equal the account's currency and req.EndDate, when
// set, must not be before req.StartDate. The first occurrence is scheduled on
// req.StartDate.
func (s *RecurringTransactionService) Create(req CreateRecurringTransactionRequest) (*models.RecurringTransaction, error) {
	account, accountErr := s.accountRepo.GetByID(req.UserID, req.AccountID)
	if accountErr != nil {
		if !errors.Is(accountErr, repository.ErrAccountNotFound) {
			return nil, fmt.Errorf("failed to validate account: %w", accountErr)
		}
		return nil, fmt.Errorf("account not found")
	}

	if _, categoryErr := s.categoryRepo.GetByID(req.UserID, req.CategoryID); categoryErr != nil {
		if !errors.Is(categoryErr, repository.ErrCategoryNotFound) {
			return nil, fmt.Errorf("failed to validate category: %w", categoryErr)
		}
		return nil, fmt.Errorf("category not found")
	}

	if account.Currency != req.Currency {
		return nil, fmt.Errorf("currency mismatch: transaction currency %s does not match account currency %s", req.Currency, account.Currency)
	}

	if end := req.EndDate; end != nil && end.Before(req.StartDate) {
		return nil, fmt.Errorf("end date must be after start date")
	}

	created := &models.RecurringTransaction{
		UserID:     req.UserID,
		Amount:     req.Amount,
		Type:       req.Type,
		CategoryID: req.CategoryID,
		AccountID:  req.AccountID,
		Currency:   req.Currency,
		Note:       req.Note,
		Frequency:  req.Frequency,
		StartDate:  req.StartDate,
		EndDate:    req.EndDate,
		// The very first occurrence happens on the start date itself.
		NextOccurrence: req.StartDate,
		IsActive:       true,
	}

	if createErr := s.recurringRepo.Create(created); createErr != nil {
		return nil, fmt.Errorf("failed to create recurring transaction: %w", createErr)
	}
	return created, nil
}
// GetByID returns the recurring transaction id, with the relations the
// repository's GetByIDWithRelations loads, provided it belongs to userID.
// Repository errors are returned unchanged; an ownership mismatch yields
// repository.ErrRecurringTransactionNotFound.
func (s *RecurringTransactionService) GetByID(userID, id uint) (*models.RecurringTransaction, error) {
	found, err := s.recurringRepo.GetByIDWithRelations(userID, id)
	switch {
	case err != nil:
		return nil, err
	case found.UserID != userID:
		// Defensive ownership check in case the lookup does not already filter by user.
		return nil, repository.ErrRecurringTransactionNotFound
	default:
		return found, nil
	}
}
// Update applies the non-nil fields of req to the recurring transaction id owned
// by userID and persists the result.
//
// Validation happens before anything is saved:
//   - a new CategoryID or AccountID must resolve for userID;
//   - the resulting currency must equal the resulting account's currency;
//   - when StartDate or EndDate is changed, the resulting EndDate (if any) must
//     not be before the resulting StartDate.
//
// Changing Frequency to a different value moves NextOccurrence one step of the
// new frequency past the currently scheduled occurrence. Re-sending the
// unchanged frequency leaves the schedule alone. ClearEndDate takes precedence
// over EndDate.
//
// Errors from looking up the recurring transaction itself are returned
// unchanged; an ownership mismatch yields repository.ErrRecurringTransactionNotFound.
func (s *RecurringTransactionService) Update(userID, id uint, req UpdateRecurringTransactionRequest) (*models.RecurringTransaction, error) {
	recurringTransaction, err := s.recurringRepo.GetByID(userID, id)
	if err != nil {
		return nil, err
	}
	if recurringTransaction.UserID != userID {
		return nil, repository.ErrRecurringTransactionNotFound
	}

	if req.Amount != nil {
		recurringTransaction.Amount = *req.Amount
	}
	if req.Type != nil {
		recurringTransaction.Type = *req.Type
	}
	if req.CategoryID != nil {
		if _, err := s.categoryRepo.GetByID(userID, *req.CategoryID); err != nil {
			if errors.Is(err, repository.ErrCategoryNotFound) {
				return nil, fmt.Errorf("category not found")
			}
			return nil, fmt.Errorf("failed to validate category: %w", err)
		}
		recurringTransaction.CategoryID = *req.CategoryID
	}
	if req.AccountID != nil {
		account, err := s.accountRepo.GetByID(userID, *req.AccountID)
		if err != nil {
			if errors.Is(err, repository.ErrAccountNotFound) {
				return nil, fmt.Errorf("account not found")
			}
			return nil, fmt.Errorf("failed to validate account: %w", err)
		}
		// If the currency is changing too, it is validated against the new account below.
		if req.Currency == nil && recurringTransaction.Currency != account.Currency {
			return nil, fmt.Errorf("currency mismatch: transaction currency %s does not match account currency %s", recurringTransaction.Currency, account.Currency)
		}
		recurringTransaction.AccountID = *req.AccountID
	}
	if req.Currency != nil {
		// AccountID already reflects req.AccountID at this point.
		account, err := s.accountRepo.GetByID(userID, recurringTransaction.AccountID)
		if err != nil {
			return nil, fmt.Errorf("failed to validate account: %w", err)
		}
		if *req.Currency != account.Currency {
			return nil, fmt.Errorf("currency mismatch: transaction currency %s does not match account currency %s", *req.Currency, account.Currency)
		}
		recurringTransaction.Currency = *req.Currency
	}
	if req.Note != nil {
		recurringTransaction.Note = *req.Note
	}
	if req.Frequency != nil && *req.Frequency != recurringTransaction.Frequency {
		recurringTransaction.Frequency = *req.Frequency
		// Only an actual change reschedules. Re-sending the current frequency (e.g. a
		// form that submits every field) must not push NextOccurrence one period
		// forward and silently skip the pending occurrence.
		recurringTransaction.NextOccurrence = s.CalculateNextOccurrence(recurringTransaction.NextOccurrence, *req.Frequency)
	}
	if req.StartDate != nil {
		recurringTransaction.StartDate = *req.StartDate
	}
	if req.ClearEndDate {
		// Explicitly clear the end date.
		recurringTransaction.EndDate = nil
	} else if req.EndDate != nil {
		recurringTransaction.EndDate = req.EndDate
	}
	// Validate the resulting range whenever either bound changed. This also rejects
	// moving StartDate past an already-stored EndDate, not only a new EndDate.
	datesChanged := req.StartDate != nil || (!req.ClearEndDate && req.EndDate != nil)
	if datesChanged && recurringTransaction.EndDate != nil && recurringTransaction.EndDate.Before(recurringTransaction.StartDate) {
		return nil, fmt.Errorf("end date must be after start date")
	}
	if req.IsActive != nil {
		recurringTransaction.IsActive = *req.IsActive
	}

	if err := s.recurringRepo.Update(recurringTransaction); err != nil {
		return nil, fmt.Errorf("failed to update recurring transaction: %w", err)
	}
	return recurringTransaction, nil
}
// Delete removes the recurring transaction id after confirming it belongs to
// userID. Lookup and delete errors from the repository are returned unchanged;
// an ownership mismatch yields repository.ErrRecurringTransactionNotFound.
func (s *RecurringTransactionService) Delete(userID, id uint) error {
	existing, lookupErr := s.recurringRepo.GetByID(userID, id)
	if lookupErr != nil {
		return lookupErr
	}
	if existing.UserID == userID {
		return s.recurringRepo.Delete(userID, id)
	}
	return repository.ErrRecurringTransactionNotFound
}
// List returns every recurring transaction the repository reports for userID.
func (s *RecurringTransactionService) List(userID uint) ([]models.RecurringTransaction, error) {
	items, err := s.recurringRepo.List(userID)
	return items, err
}
// GetActive returns the recurring transactions the repository reports as active
// for userID.
func (s *RecurringTransactionService) GetActive(userID uint) ([]models.RecurringTransaction, error) {
	active, err := s.recurringRepo.GetActive(userID)
	return active, err
}
// CalculateNextOccurrence returns the occurrence that follows currentDate for the
// given frequency, keeping currentDate's clock time and location.
//
// Monthly and yearly steps clamp the day to the last day of the target month.
// time.Time.AddDate would instead normalize the overflow into the following
// month: Jan 31 + 1 month = Mar 2/3, which skips February completely, and
// Feb 29 + 1 year = Mar 1. With clamping, Jan 31 -> Feb 28/29 and
// Feb 29 -> Feb 28. Only currentDate is known here, so later occurrences continue
// from the clamped day.
//
// Unknown frequencies fall back to a daily step.
func (s *RecurringTransactionService) CalculateNextOccurrence(currentDate time.Time, frequency models.FrequencyType) time.Time {
	// addMonths advances currentDate by n calendar months, clamping the day of
	// month so the result always stays inside the target month.
	addMonths := func(n int) time.Time {
		year, month, day := currentDate.Date()
		loc := currentDate.Location()
		// time.Date normalizes month overflow, so this is the 1st of the target month.
		first := time.Date(year, month+time.Month(n), 1, 0, 0, 0, 0, loc)
		if lastDay := first.AddDate(0, 1, -1).Day(); day > lastDay {
			day = lastDay
		}
		hour, minute, sec := currentDate.Clock()
		return time.Date(first.Year(), first.Month(), day, hour, minute, sec, currentDate.Nanosecond(), loc)
	}

	switch frequency {
	case models.FrequencyDaily:
		return currentDate.AddDate(0, 0, 1)
	case models.FrequencyWeekly:
		return currentDate.AddDate(0, 0, 7)
	case models.FrequencyMonthly:
		return addMonths(1)
	case models.FrequencyYearly:
		return addMonths(12)
	default:
		// Default to daily if unknown frequency
		return currentDate.AddDate(0, 0, 1)
	}
}
// ProcessDueTransactionsResult represents the result of processing due transactions.
type ProcessDueTransactionsResult struct {
// Transactions lists the transactions generated and committed during the call, in processing order.
Transactions []models.Transaction `json:"transactions"`
// Allocations holds one entry per income occurrence for which an allocation rule matched.
Allocations []AllocationResult `json:"allocations,omitempty"`
}
// ProcessDueTransactions processes all due recurring transactions for a user and
// generates actual transactions. For income transactions, it also triggers
// matching allocation rules.
//
// For each recurring transaction the repository reports as due at now, the
// following steps run inside one database transaction:
//   - create a transaction dated at NextOccurrence and linked via RecurringID;
//   - adjust the account balance (+ for income, - for expense);
//   - for income, when allocationRuleRepo is configured, apply the matching
//     allocation rule;
//   - advance NextOccurrence by one period, deactivating the definition if the
//     new occurrence lies past EndDate.
//
// Definitions whose NextOccurrence is already past EndDate are deactivated
// without generating a transaction. At most one occurrence per definition is
// generated per call.
//
// Processing stops at the first error. Occurrences committed before that point
// stay committed, but the error return discards the partial result.
func (s *RecurringTransactionService) ProcessDueTransactions(userID uint, now time.Time) (*ProcessDueTransactionsResult, error) {
	dueRecurringTransactions, err := s.recurringRepo.GetDueTransactions(userID, now)
	if err != nil {
		return nil, fmt.Errorf("failed to get due recurring transactions: %w", err)
	}

	result := &ProcessDueTransactionsResult{
		Transactions: []models.Transaction{},
		Allocations:  []AllocationResult{},
	}

	for i := range dueRecurringTransactions {
		// Work through a pointer to the slice element rather than the range value.
		// The generated transaction keeps &recurringTxn.ID. With a shared range
		// variable (Go < 1.22), every returned transaction would report the ID of
		// the last definition processed.
		recurringTxn := &dueRecurringTransactions[i]

		// Already past its end date: deactivate without generating anything.
		if recurringTxn.EndDate != nil && recurringTxn.NextOccurrence.After(*recurringTxn.EndDate) {
			recurringTxn.IsActive = false
			if err := s.recurringRepo.Update(recurringTxn); err != nil {
				return nil, fmt.Errorf("failed to deactivate recurring transaction %d: %w", recurringTxn.ID, err)
			}
			continue
		}

		// One database transaction per occurrence: the generated transaction, the
		// balance changes and the schedule update commit or roll back together.
		tx := s.db.Begin()
		if tx.Error != nil {
			return nil, fmt.Errorf("failed to begin transaction: %w", tx.Error)
		}

		transaction := models.Transaction{
			UserID:          recurringTxn.UserID,
			Amount:          recurringTxn.Amount,
			Type:            recurringTxn.Type,
			CategoryID:      recurringTxn.CategoryID,
			AccountID:       recurringTxn.AccountID,
			Currency:        recurringTxn.Currency,
			TransactionDate: recurringTxn.NextOccurrence,
			Note:            recurringTxn.Note,
			RecurringID:     &recurringTxn.ID,
		}
		if err := tx.Create(&transaction).Error; err != nil {
			tx.Rollback()
			return nil, fmt.Errorf("failed to create transaction from recurring transaction %d: %w", recurringTxn.ID, err)
		}

		// Update account balance.
		var account models.Account
		if err := tx.First(&account, recurringTxn.AccountID).Error; err != nil {
			tx.Rollback()
			return nil, fmt.Errorf("failed to get account %d: %w", recurringTxn.AccountID, err)
		}
		switch recurringTxn.Type {
		case models.TransactionTypeIncome:
			account.Balance += recurringTxn.Amount
		case models.TransactionTypeExpense:
			account.Balance -= recurringTxn.Amount
		}
		if err := tx.Save(&account).Error; err != nil {
			tx.Rollback()
			return nil, fmt.Errorf("failed to update account balance: %w", err)
		}

		// For income transactions, check and apply allocation rules.
		if recurringTxn.Type == models.TransactionTypeIncome && s.allocationRuleRepo != nil {
			allocationResult, err := s.applyAllocationRulesForIncome(userID, tx, recurringTxn.AccountID, recurringTxn.Amount)
			if err != nil {
				tx.Rollback()
				return nil, fmt.Errorf("failed to apply allocation rules: %w", err)
			}
			if allocationResult != nil {
				result.Allocations = append(result.Allocations, *allocationResult)
			}
		}

		// Advance the schedule; deactivate once it runs past the end date.
		nextOccurrence := s.CalculateNextOccurrence(recurringTxn.NextOccurrence, recurringTxn.Frequency)
		recurringTxn.NextOccurrence = nextOccurrence
		if recurringTxn.EndDate != nil && nextOccurrence.After(*recurringTxn.EndDate) {
			recurringTxn.IsActive = false
		}
		if err := tx.Save(recurringTxn).Error; err != nil {
			tx.Rollback()
			return nil, fmt.Errorf("failed to update recurring transaction %d: %w", recurringTxn.ID, err)
		}

		if err := tx.Commit().Error; err != nil {
			return nil, fmt.Errorf("failed to commit transaction: %w", err)
		}
		result.Transactions = append(result.Transactions, transaction)
	}

	return result, nil
}
// applyAllocationRulesForIncome distributes an income of amount received on
// accountID according to the first active income-triggered allocation rule for
// that account. Every read and write goes through tx, so the caller can roll
// everything back on error.
//
// Each rule target receives either Percentage% of amount or FixedAmount. The
// value is rounded to cents and capped at what is still unallocated, so the
// allocations never sum to more than amount. Each allocated value is credited to
// the target account or piggy bank and debited from accountID. Targets with
// neither a percentage nor a fixed amount, or with an unknown target type, are
// skipped. When anything was allocated, an AllocationRecord is written along with
// one AllocationRecordDetail per allocation.
//
// Returns (nil, nil) when no rule matches.
func (s *RecurringTransactionService) applyAllocationRulesForIncome(userID uint, tx *gorm.DB, accountID uint, amount float64) (*AllocationResult, error) {
	// Get active allocation rules that match the income trigger and source account.
	rules, err := s.allocationRuleRepo.GetActiveByTriggerTypeAndAccount(userID, models.TriggerTypeIncome, accountID)
	if err != nil {
		return nil, fmt.Errorf("failed to get allocation rules: %w", err)
	}
	if len(rules) == 0 {
		return nil, nil // No matching rules
	}

	// Apply only the first matching rule (can be extended to apply multiple rules).
	rule := rules[0]

	result := &AllocationResult{
		RuleID:      rule.ID,
		RuleName:    rule.Name,
		TotalAmount: amount,
		Allocations: []AllocationDetail{},
	}
	totalAllocated := 0.0

	for _, target := range rule.Targets {
		var allocatedAmount float64
		switch {
		case target.Percentage != nil:
			allocatedAmount = amount * (*target.Percentage / 100.0)
		case target.FixedAmount != nil:
			allocatedAmount = *target.FixedAmount
		default:
			continue // Skip invalid target
		}
		// Round to 2 decimal places (half away from zero; inputs are non-negative here).
		allocatedAmount = math.Round(allocatedAmount*100) / 100

		// Cap every target at what is still unallocated, percentage-based ones
		// included. Otherwise percentages summing past 100%, a fixed target
		// followed by percentages, or per-target rounding could allocate more than
		// the income and drive Remaining negative. The remainder is rounded to
		// cents as well (amounts are assumed to be cent-precise).
		if remaining := math.Round((amount-totalAllocated)*100) / 100; allocatedAmount > remaining {
			allocatedAmount = remaining
		}
		if allocatedAmount <= 0 {
			continue
		}

		// Credit the target and remember its display name.
		var targetName string
		switch target.TargetType {
		case models.TargetTypeAccount:
			var targetAccount models.Account
			if err := tx.First(&targetAccount, target.TargetID).Error; err != nil {
				return nil, fmt.Errorf("failed to get target account: %w", err)
			}
			targetName = targetAccount.Name
			targetAccount.Balance += allocatedAmount
			if err := tx.Save(&targetAccount).Error; err != nil {
				return nil, fmt.Errorf("failed to update target account balance: %w", err)
			}
		case models.TargetTypePiggyBank:
			var piggyBank models.PiggyBank
			if err := tx.First(&piggyBank, target.TargetID).Error; err != nil {
				return nil, fmt.Errorf("failed to get target piggy bank: %w", err)
			}
			targetName = piggyBank.Name
			piggyBank.CurrentAmount += allocatedAmount
			if err := tx.Save(&piggyBank).Error; err != nil {
				return nil, fmt.Errorf("failed to update piggy bank balance: %w", err)
			}
		default:
			continue // Skip invalid target type
		}

		// Debit the source account. It is re-read inside tx after the target was
		// credited, so a target that is the source account itself nets to zero.
		var sourceAccount models.Account
		if err := tx.First(&sourceAccount, accountID).Error; err != nil {
			return nil, fmt.Errorf("failed to get source account: %w", err)
		}
		sourceAccount.Balance -= allocatedAmount
		if err := tx.Save(&sourceAccount).Error; err != nil {
			return nil, fmt.Errorf("failed to update source account balance: %w", err)
		}

		result.Allocations = append(result.Allocations, AllocationDetail{
			TargetType:  target.TargetType,
			TargetID:    target.TargetID,
			TargetName:  targetName,
			Amount:      allocatedAmount,
			Percentage:  target.Percentage,
			FixedAmount: target.FixedAmount,
		})
		totalAllocated += allocatedAmount
	}

	result.AllocatedAmount = totalAllocated
	result.Remaining = amount - totalAllocated

	// Persist an allocation record only when something was actually allocated.
	if totalAllocated > 0 {
		allocationRecord := &models.AllocationRecord{
			UserID:          userID,
			RuleID:          rule.ID,
			RuleName:        rule.Name,
			SourceAccountID: accountID,
			TotalAmount:     amount,
			AllocatedAmount: totalAllocated,
			RemainingAmount: result.Remaining,
			// "Recurring income auto-allocation (rule: %s)"
			Note: fmt.Sprintf("周期性收入自动分配 (规则: %s)", rule.Name),
		}
		if err := tx.Create(allocationRecord).Error; err != nil {
			return nil, fmt.Errorf("failed to create allocation record: %w", err)
		}

		// Save allocation record details.
		for _, allocation := range result.Allocations {
			detail := &models.AllocationRecordDetail{
				RecordID:    allocationRecord.ID,
				TargetType:  allocation.TargetType,
				TargetID:    allocation.TargetID,
				TargetName:  allocation.TargetName,
				Amount:      allocation.Amount,
				Percentage:  allocation.Percentage,
				FixedAmount: allocation.FixedAmount,
			}
			if err := tx.Create(detail).Error; err != nil {
				return nil, fmt.Errorf("failed to create allocation record detail: %w", err)
			}
		}
	}

	return result, nil
}
// GetByAccountID returns userID's recurring transactions that reference
// accountID, exactly as the repository reports them.
func (s *RecurringTransactionService) GetByAccountID(userID, accountID uint) ([]models.RecurringTransaction, error) {
	matches, err := s.recurringRepo.GetByAccountID(userID, accountID)
	return matches, err
}
// GetByCategoryID returns userID's recurring transactions that reference
// categoryID, exactly as the repository reports them.
func (s *RecurringTransactionService) GetByCategoryID(userID, categoryID uint) ([]models.RecurringTransaction, error) {
	matches, err := s.recurringRepo.GetByCategoryID(userID, categoryID)
	return matches, err
}