package service import ( "encoding/csv" "errors" "fmt" "io" "strconv" "strings" "time" "accounting-app/internal/models" "accounting-app/internal/repository" ) // Import service errors var ( ErrInvalidFileFormat = errors.New("invalid file format") ErrEmptyFile = errors.New("file is empty") ErrInvalidHeader = errors.New("invalid or missing header row") ErrInvalidRowData = errors.New("invalid row data") ) // ImportResult represents the result of a batch import operation type ImportResult struct { TotalRows int `json:"total_rows"` SuccessCount int `json:"success_count"` FailedCount int `json:"failed_count"` Errors []ImportError `json:"errors,omitempty"` Transactions []uint `json:"transaction_ids,omitempty"` } // ImportError represents an error that occurred during import type ImportError struct { Row int `json:"row"` Column string `json:"column,omitempty"` Message string `json:"message"` } // TransactionImportRow represents a single row of transaction data to import type TransactionImportRow struct { Date string `json:"date"` // Required: YYYY-MM-DD format Amount float64 `json:"amount"` // Required: positive number Type string `json:"type"` // Required: income/expense/transfer Category string `json:"category"` // Required: category name Account string `json:"account"` // Required: account name Note string `json:"note"` // Optional Currency string `json:"currency"` // Optional: defaults to CNY ToAccount string `json:"to_account"` // Optional: for transfers } // ImportService handles batch import of transactions type ImportService struct { transactionRepo *repository.TransactionRepository categoryRepo *repository.CategoryRepository accountRepo *repository.AccountRepository } // NewImportService creates a new ImportService instance func NewImportService( transactionRepo *repository.TransactionRepository, categoryRepo *repository.CategoryRepository, accountRepo *repository.AccountRepository, ) *ImportService { return &ImportService{ transactionRepo: transactionRepo, categoryRepo: categoryRepo, accountRepo: accountRepo, } } // ImportFromCSV imports transactions from a CSV file // Expected CSV format: date,amount,type,category,account,note,currency,to_account func (s *ImportService) ImportFromCSV(userID uint, reader io.Reader) (*ImportResult, error) { csvReader := csv.NewReader(reader) csvReader.TrimLeadingSpace = true // Read header row header, err := csvReader.Read() if err != nil { if err == io.EOF { return nil, ErrEmptyFile } return nil, fmt.Errorf("failed to read header: %w", err) } // Validate and map header columns columnMap, err := s.parseHeader(header) if err != nil { return nil, err } result := &ImportResult{ Errors: make([]ImportError, 0), Transactions: make([]uint, 0), } rowNum := 1 // Start from 1 (after header) for { record, err := csvReader.Read() if err == io.EOF { break } if err != nil { result.Errors = append(result.Errors, ImportError{ Row: rowNum, Message: fmt.Sprintf("failed to read row: %v", err), }) result.FailedCount++ rowNum++ continue } result.TotalRows++ rowNum++ // Parse row data row, parseErr := s.parseRow(record, columnMap, rowNum) if parseErr != nil { result.Errors = append(result.Errors, *parseErr) result.FailedCount++ continue } // Create transaction txID, createErr := s.createTransaction(userID, row, rowNum) if createErr != nil { result.Errors = append(result.Errors, *createErr) result.FailedCount++ continue } result.SuccessCount++ result.Transactions = append(result.Transactions, txID) } return result, nil } // parseHeader validates and maps CSV header columns func (s *ImportService) parseHeader(header []string) (map[string]int, error) { columnMap := make(map[string]int) requiredColumns := []string{"date", "amount", "type", "category", "account"} for i, col := range header { normalizedCol := strings.ToLower(strings.TrimSpace(col)) columnMap[normalizedCol] = i } // Check required columns for _, required := range requiredColumns { if _, ok := columnMap[required]; !ok { return nil, fmt.Errorf("%w: missing required column '%s'", ErrInvalidHeader, required) } } return columnMap, nil } // parseRow parses a CSV row into TransactionImportRow func (s *ImportService) parseRow(record []string, columnMap map[string]int, rowNum int) (*TransactionImportRow, *ImportError) { getValue := func(col string) string { if idx, ok := columnMap[col]; ok && idx < len(record) { return strings.TrimSpace(record[idx]) } return "" } row := &TransactionImportRow{ Date: getValue("date"), Type: getValue("type"), Category: getValue("category"), Account: getValue("account"), Note: getValue("note"), Currency: getValue("currency"), ToAccount: getValue("to_account"), } // Parse amount amountStr := getValue("amount") if amountStr == "" { return nil, &ImportError{Row: rowNum, Column: "amount", Message: "amount is required"} } amount, err := strconv.ParseFloat(amountStr, 64) if err != nil { return nil, &ImportError{Row: rowNum, Column: "amount", Message: "invalid amount format"} } row.Amount = amount // Validate required fields if row.Date == "" { return nil, &ImportError{Row: rowNum, Column: "date", Message: "date is required"} } if row.Type == "" { return nil, &ImportError{Row: rowNum, Column: "type", Message: "type is required"} } if row.Category == "" { return nil, &ImportError{Row: rowNum, Column: "category", Message: "category is required"} } if row.Account == "" { return nil, &ImportError{Row: rowNum, Column: "account", Message: "account is required"} } return row, nil } // createTransaction creates a transaction from import row data func (s *ImportService) createTransaction(userID uint, row *TransactionImportRow, rowNum int) (uint, *ImportError) { // Parse date date, err := time.Parse("2006-01-02", row.Date) if err != nil { // Try alternative formats date, err = time.Parse("2006/01/02", row.Date) if err != nil { return 0, &ImportError{Row: rowNum, Column: "date", Message: "invalid date format, expected YYYY-MM-DD"} } } // Parse transaction type txType, err := s.parseTransactionType(row.Type) if err != nil { return 0, &ImportError{Row: rowNum, Column: "type", Message: err.Error()} } // Find category by name category, err := s.categoryRepo.GetByName(userID, row.Category) if err != nil { return 0, &ImportError{Row: rowNum, Column: "category", Message: fmt.Sprintf("category '%s' not found", row.Category)} } // Find account by name account, err := s.accountRepo.GetByName(userID, row.Account) if err != nil { return 0, &ImportError{Row: rowNum, Column: "account", Message: fmt.Sprintf("account '%s' not found", row.Account)} } // Parse currency currency := models.CurrencyCNY if row.Currency != "" { currency = models.Currency(strings.ToUpper(row.Currency)) } // Create transaction tx := &models.Transaction{ UserID: userID, Amount: row.Amount, Type: txType, CategoryID: category.ID, AccountID: account.ID, Currency: currency, TransactionDate: date, Note: row.Note, } // Handle transfer transactions if txType == models.TransactionTypeTransfer && row.ToAccount != "" { toAccount, err := s.accountRepo.GetByName(userID, row.ToAccount) if err != nil { return 0, &ImportError{Row: rowNum, Column: "to_account", Message: fmt.Sprintf("to_account '%s' not found", row.ToAccount)} } tx.ToAccountID = &toAccount.ID } // Save transaction if err := s.transactionRepo.Create(tx); err != nil { return 0, &ImportError{Row: rowNum, Message: fmt.Sprintf("failed to create transaction: %v", err)} } return tx.ID, nil } // parseTransactionType converts string to TransactionType func (s *ImportService) parseTransactionType(typeStr string) (models.TransactionType, error) { switch strings.ToLower(typeStr) { case "income", "收入": return models.TransactionTypeIncome, nil case "expense", "支出": return models.TransactionTypeExpense, nil case "transfer", "转账": return models.TransactionTypeTransfer, nil default: return "", fmt.Errorf("invalid transaction type '%s', expected income/expense/transfer", typeStr) } } // GenerateCSVTemplate generates a CSV template for import func (s *ImportService) GenerateCSVTemplate() string { header := "date,amount,type,category,account,note,currency,to_account\n" example := "2024-01-15,100.00,expense,餐饮,现金,午餐,CNY,\n" example += "2024-01-16,5000.00,income,工资,银行�?月薪,CNY,\n" example += "2024-01-17,200.00,transfer,转账,银行�?转到支付�?CNY,支付宝\n" return header + example } // ValidateImportData validates import data without creating transactions func (s *ImportService) ValidateImportData(userID uint, reader io.Reader) (*ImportResult, error) { csvReader := csv.NewReader(reader) csvReader.TrimLeadingSpace = true // Read header row header, err := csvReader.Read() if err != nil { if err == io.EOF { return nil, ErrEmptyFile } return nil, fmt.Errorf("failed to read header: %w", err) } // Validate and map header columns columnMap, err := s.parseHeader(header) if err != nil { return nil, err } result := &ImportResult{ Errors: make([]ImportError, 0), } rowNum := 1 for { record, err := csvReader.Read() if err == io.EOF { break } if err != nil { result.Errors = append(result.Errors, ImportError{ Row: rowNum, Message: fmt.Sprintf("failed to read row: %v", err), }) result.FailedCount++ rowNum++ continue } result.TotalRows++ rowNum++ // Parse and validate row data row, parseErr := s.parseRow(record, columnMap, rowNum) if parseErr != nil { result.Errors = append(result.Errors, *parseErr) result.FailedCount++ continue } // Validate references exist if validateErr := s.validateRow(userID, row, rowNum); validateErr != nil { result.Errors = append(result.Errors, *validateErr) result.FailedCount++ continue } result.SuccessCount++ } return result, nil } // validateRow validates that all references in a row exist func (s *ImportService) validateRow(userID uint, row *TransactionImportRow, rowNum int) *ImportError { // Validate date format _, err := time.Parse("2006-01-02", row.Date) if err != nil { _, err = time.Parse("2006/01/02", row.Date) if err != nil { return &ImportError{Row: rowNum, Column: "date", Message: "invalid date format"} } } // Validate transaction type if _, err := s.parseTransactionType(row.Type); err != nil { return &ImportError{Row: rowNum, Column: "type", Message: err.Error()} } // Validate category exists if _, err := s.categoryRepo.GetByName(userID, row.Category); err != nil { return &ImportError{Row: rowNum, Column: "category", Message: fmt.Sprintf("category '%s' not found", row.Category)} } // Validate account exists if _, err := s.accountRepo.GetByName(userID, row.Account); err != nil { return &ImportError{Row: rowNum, Column: "account", Message: fmt.Sprintf("account '%s' not found", row.Account)} } // Validate to_account for transfers if strings.ToLower(row.Type) == "transfer" && row.ToAccount != "" { if _, err := s.accountRepo.GetByName(userID, row.ToAccount); err != nil { return &ImportError{Row: rowNum, Column: "to_account", Message: fmt.Sprintf("to_account '%s' not found", row.ToAccount)} } } return nil }