394 lines
11 KiB
Go
394 lines
11 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/csv"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"accounting-app/internal/models"
|
||
"accounting-app/internal/repository"
|
||
)
|
||
|
||
// ErrInvalidFileFormat reports input that is not in a supported file format.
// NOTE(review): nothing visible in this file returns it; confirm callers need it.
var ErrInvalidFileFormat = errors.New("invalid file format")

// ErrEmptyFile is returned when the input contains no header record at all.
var ErrEmptyFile = errors.New("file is empty")

// ErrInvalidHeader is wrapped, together with the name of the missing column,
// when the header row lacks a required column.
var ErrInvalidHeader = errors.New("invalid or missing header row")

// ErrInvalidRowData reports a data row that cannot be used.
// NOTE(review): nothing visible in this file returns it; per-row problems are
// reported through ImportError values instead.
var ErrInvalidRowData = errors.New("invalid row data")
|
||
|
||
// ImportResult represents the result of a batch import operation
|
||
type ImportResult struct {
|
||
TotalRows int `json:"total_rows"`
|
||
SuccessCount int `json:"success_count"`
|
||
FailedCount int `json:"failed_count"`
|
||
Errors []ImportError `json:"errors,omitempty"`
|
||
Transactions []uint `json:"transaction_ids,omitempty"`
|
||
}
|
||
|
||
// ImportError describes why a single CSV row could not be imported.
type ImportError struct {
	Row     int    `json:"row"`              // record number, counting the header as row 1
	Column  string `json:"column,omitempty"` // offending column; empty when not tied to one
	Message string `json:"message"`          // human-readable reason
}
|
||
|
||
// TransactionImportRow holds the trimmed cell values of one CSV data row.
// Only Amount is converted while parsing; the date, the type and the
// category/account names are interpreted later, when the row is validated or
// imported.
type TransactionImportRow struct {
	Date      string  `json:"date"`       // required; YYYY-MM-DD (YYYY/MM/DD is also accepted)
	Amount    float64 `json:"amount"`     // required; positive number
	Type      string  `json:"type"`       // required; income/expense/transfer or an accepted alias
	Category  string  `json:"category"`   // required; category name
	Account   string  `json:"account"`    // required; account name
	Note      string  `json:"note"`       // optional
	Currency  string  `json:"currency"`   // optional; defaults to CNY
	ToAccount string  `json:"to_account"` // optional; destination account name for transfers
}
|
||
|
||
// ImportService handles batch import of transactions
|
||
type ImportService struct {
|
||
transactionRepo *repository.TransactionRepository
|
||
categoryRepo *repository.CategoryRepository
|
||
accountRepo *repository.AccountRepository
|
||
}
|
||
|
||
// NewImportService creates a new ImportService instance
|
||
func NewImportService(
|
||
transactionRepo *repository.TransactionRepository,
|
||
categoryRepo *repository.CategoryRepository,
|
||
accountRepo *repository.AccountRepository,
|
||
) *ImportService {
|
||
return &ImportService{
|
||
transactionRepo: transactionRepo,
|
||
categoryRepo: categoryRepo,
|
||
accountRepo: accountRepo,
|
||
}
|
||
}
|
||
|
||
// ImportFromCSV imports transactions from a CSV file
|
||
// Expected CSV format: date,amount,type,category,account,note,currency,to_account
|
||
func (s *ImportService) ImportFromCSV(userID uint, reader io.Reader) (*ImportResult, error) {
|
||
csvReader := csv.NewReader(reader)
|
||
csvReader.TrimLeadingSpace = true
|
||
|
||
// Read header row
|
||
header, err := csvReader.Read()
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
return nil, ErrEmptyFile
|
||
}
|
||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||
}
|
||
|
||
// Validate and map header columns
|
||
columnMap, err := s.parseHeader(header)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := &ImportResult{
|
||
Errors: make([]ImportError, 0),
|
||
Transactions: make([]uint, 0),
|
||
}
|
||
|
||
rowNum := 1 // Start from 1 (after header)
|
||
for {
|
||
record, err := csvReader.Read()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
result.Errors = append(result.Errors, ImportError{
|
||
Row: rowNum,
|
||
Message: fmt.Sprintf("failed to read row: %v", err),
|
||
})
|
||
result.FailedCount++
|
||
rowNum++
|
||
continue
|
||
}
|
||
|
||
result.TotalRows++
|
||
rowNum++
|
||
|
||
// Parse row data
|
||
row, parseErr := s.parseRow(record, columnMap, rowNum)
|
||
if parseErr != nil {
|
||
result.Errors = append(result.Errors, *parseErr)
|
||
result.FailedCount++
|
||
continue
|
||
}
|
||
|
||
// Create transaction
|
||
txID, createErr := s.createTransaction(userID, row, rowNum)
|
||
if createErr != nil {
|
||
result.Errors = append(result.Errors, *createErr)
|
||
result.FailedCount++
|
||
continue
|
||
}
|
||
|
||
result.SuccessCount++
|
||
result.Transactions = append(result.Transactions, txID)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// parseHeader validates and maps CSV header columns
|
||
func (s *ImportService) parseHeader(header []string) (map[string]int, error) {
|
||
columnMap := make(map[string]int)
|
||
requiredColumns := []string{"date", "amount", "type", "category", "account"}
|
||
|
||
for i, col := range header {
|
||
normalizedCol := strings.ToLower(strings.TrimSpace(col))
|
||
columnMap[normalizedCol] = i
|
||
}
|
||
|
||
// Check required columns
|
||
for _, required := range requiredColumns {
|
||
if _, ok := columnMap[required]; !ok {
|
||
return nil, fmt.Errorf("%w: missing required column '%s'", ErrInvalidHeader, required)
|
||
}
|
||
}
|
||
|
||
return columnMap, nil
|
||
}
|
||
|
||
// parseRow parses a CSV row into TransactionImportRow
|
||
func (s *ImportService) parseRow(record []string, columnMap map[string]int, rowNum int) (*TransactionImportRow, *ImportError) {
|
||
getValue := func(col string) string {
|
||
if idx, ok := columnMap[col]; ok && idx < len(record) {
|
||
return strings.TrimSpace(record[idx])
|
||
}
|
||
return ""
|
||
}
|
||
|
||
row := &TransactionImportRow{
|
||
Date: getValue("date"),
|
||
Type: getValue("type"),
|
||
Category: getValue("category"),
|
||
Account: getValue("account"),
|
||
Note: getValue("note"),
|
||
Currency: getValue("currency"),
|
||
ToAccount: getValue("to_account"),
|
||
}
|
||
|
||
// Parse amount
|
||
amountStr := getValue("amount")
|
||
if amountStr == "" {
|
||
return nil, &ImportError{Row: rowNum, Column: "amount", Message: "amount is required"}
|
||
}
|
||
amount, err := strconv.ParseFloat(amountStr, 64)
|
||
if err != nil {
|
||
return nil, &ImportError{Row: rowNum, Column: "amount", Message: "invalid amount format"}
|
||
}
|
||
row.Amount = amount
|
||
|
||
// Validate required fields
|
||
if row.Date == "" {
|
||
return nil, &ImportError{Row: rowNum, Column: "date", Message: "date is required"}
|
||
}
|
||
if row.Type == "" {
|
||
return nil, &ImportError{Row: rowNum, Column: "type", Message: "type is required"}
|
||
}
|
||
if row.Category == "" {
|
||
return nil, &ImportError{Row: rowNum, Column: "category", Message: "category is required"}
|
||
}
|
||
if row.Account == "" {
|
||
return nil, &ImportError{Row: rowNum, Column: "account", Message: "account is required"}
|
||
}
|
||
|
||
return row, nil
|
||
}
|
||
|
||
// createTransaction creates a transaction from import row data
|
||
func (s *ImportService) createTransaction(userID uint, row *TransactionImportRow, rowNum int) (uint, *ImportError) {
|
||
// Parse date
|
||
date, err := time.Parse("2006-01-02", row.Date)
|
||
if err != nil {
|
||
// Try alternative formats
|
||
date, err = time.Parse("2006/01/02", row.Date)
|
||
if err != nil {
|
||
return 0, &ImportError{Row: rowNum, Column: "date", Message: "invalid date format, expected YYYY-MM-DD"}
|
||
}
|
||
}
|
||
|
||
// Parse transaction type
|
||
txType, err := s.parseTransactionType(row.Type)
|
||
if err != nil {
|
||
return 0, &ImportError{Row: rowNum, Column: "type", Message: err.Error()}
|
||
}
|
||
|
||
// Find category by name
|
||
category, err := s.categoryRepo.GetByName(userID, row.Category)
|
||
if err != nil {
|
||
return 0, &ImportError{Row: rowNum, Column: "category", Message: fmt.Sprintf("category '%s' not found", row.Category)}
|
||
}
|
||
|
||
// Find account by name
|
||
account, err := s.accountRepo.GetByName(userID, row.Account)
|
||
if err != nil {
|
||
return 0, &ImportError{Row: rowNum, Column: "account", Message: fmt.Sprintf("account '%s' not found", row.Account)}
|
||
}
|
||
|
||
// Parse currency
|
||
currency := models.CurrencyCNY
|
||
if row.Currency != "" {
|
||
currency = models.Currency(strings.ToUpper(row.Currency))
|
||
}
|
||
|
||
// Create transaction
|
||
tx := &models.Transaction{
|
||
UserID: userID,
|
||
Amount: row.Amount,
|
||
Type: txType,
|
||
CategoryID: category.ID,
|
||
AccountID: account.ID,
|
||
Currency: currency,
|
||
TransactionDate: date,
|
||
Note: row.Note,
|
||
}
|
||
|
||
// Handle transfer transactions
|
||
if txType == models.TransactionTypeTransfer && row.ToAccount != "" {
|
||
toAccount, err := s.accountRepo.GetByName(userID, row.ToAccount)
|
||
if err != nil {
|
||
return 0, &ImportError{Row: rowNum, Column: "to_account", Message: fmt.Sprintf("to_account '%s' not found", row.ToAccount)}
|
||
}
|
||
tx.ToAccountID = &toAccount.ID
|
||
}
|
||
|
||
// Save transaction
|
||
if err := s.transactionRepo.Create(tx); err != nil {
|
||
return 0, &ImportError{Row: rowNum, Message: fmt.Sprintf("failed to create transaction: %v", err)}
|
||
}
|
||
|
||
return tx.ID, nil
|
||
}
|
||
|
||
// parseTransactionType converts string to TransactionType
|
||
func (s *ImportService) parseTransactionType(typeStr string) (models.TransactionType, error) {
|
||
switch strings.ToLower(typeStr) {
|
||
case "income", "收入":
|
||
return models.TransactionTypeIncome, nil
|
||
case "expense", "支出":
|
||
return models.TransactionTypeExpense, nil
|
||
case "transfer", "转账":
|
||
return models.TransactionTypeTransfer, nil
|
||
default:
|
||
return "", fmt.Errorf("invalid transaction type '%s', expected income/expense/transfer", typeStr)
|
||
}
|
||
}
|
||
|
||
// GenerateCSVTemplate generates a CSV template for import
|
||
func (s *ImportService) GenerateCSVTemplate() string {
|
||
header := "date,amount,type,category,account,note,currency,to_account\n"
|
||
example := "2024-01-15,100.00,expense,餐饮,现金,午餐,CNY,\n"
|
||
example += "2024-01-16,5000.00,income,工资,银行<E993B6>?月薪,CNY,\n"
|
||
example += "2024-01-17,200.00,transfer,转账,银行<E993B6>?转到支付<E694AF>?CNY,支付宝\n"
|
||
return header + example
|
||
}
|
||
|
||
// ValidateImportData validates import data without creating transactions
|
||
func (s *ImportService) ValidateImportData(userID uint, reader io.Reader) (*ImportResult, error) {
|
||
csvReader := csv.NewReader(reader)
|
||
csvReader.TrimLeadingSpace = true
|
||
|
||
// Read header row
|
||
header, err := csvReader.Read()
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
return nil, ErrEmptyFile
|
||
}
|
||
return nil, fmt.Errorf("failed to read header: %w", err)
|
||
}
|
||
|
||
// Validate and map header columns
|
||
columnMap, err := s.parseHeader(header)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := &ImportResult{
|
||
Errors: make([]ImportError, 0),
|
||
}
|
||
|
||
rowNum := 1
|
||
for {
|
||
record, err := csvReader.Read()
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
if err != nil {
|
||
result.Errors = append(result.Errors, ImportError{
|
||
Row: rowNum,
|
||
Message: fmt.Sprintf("failed to read row: %v", err),
|
||
})
|
||
result.FailedCount++
|
||
rowNum++
|
||
continue
|
||
}
|
||
|
||
result.TotalRows++
|
||
rowNum++
|
||
|
||
// Parse and validate row data
|
||
row, parseErr := s.parseRow(record, columnMap, rowNum)
|
||
if parseErr != nil {
|
||
result.Errors = append(result.Errors, *parseErr)
|
||
result.FailedCount++
|
||
continue
|
||
}
|
||
|
||
// Validate references exist
|
||
if validateErr := s.validateRow(userID, row, rowNum); validateErr != nil {
|
||
result.Errors = append(result.Errors, *validateErr)
|
||
result.FailedCount++
|
||
continue
|
||
}
|
||
|
||
result.SuccessCount++
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// validateRow validates that all references in a row exist
|
||
func (s *ImportService) validateRow(userID uint, row *TransactionImportRow, rowNum int) *ImportError {
|
||
// Validate date format
|
||
_, err := time.Parse("2006-01-02", row.Date)
|
||
if err != nil {
|
||
_, err = time.Parse("2006/01/02", row.Date)
|
||
if err != nil {
|
||
return &ImportError{Row: rowNum, Column: "date", Message: "invalid date format"}
|
||
}
|
||
}
|
||
|
||
// Validate transaction type
|
||
if _, err := s.parseTransactionType(row.Type); err != nil {
|
||
return &ImportError{Row: rowNum, Column: "type", Message: err.Error()}
|
||
}
|
||
|
||
// Validate category exists
|
||
if _, err := s.categoryRepo.GetByName(userID, row.Category); err != nil {
|
||
return &ImportError{Row: rowNum, Column: "category", Message: fmt.Sprintf("category '%s' not found", row.Category)}
|
||
}
|
||
|
||
// Validate account exists
|
||
if _, err := s.accountRepo.GetByName(userID, row.Account); err != nil {
|
||
return &ImportError{Row: rowNum, Column: "account", Message: fmt.Sprintf("account '%s' not found", row.Account)}
|
||
}
|
||
|
||
// Validate to_account for transfers
|
||
if strings.ToLower(row.Type) == "transfer" && row.ToAccount != "" {
|
||
if _, err := s.accountRepo.GetByName(userID, row.ToAccount); err != nil {
|
||
return &ImportError{Row: rowNum, Column: "to_account", Message: fmt.Sprintf("to_account '%s' not found", row.ToAccount)}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|