Files
Novault-backend/internal/service/ai_bookkeeping_service.go
2026-01-25 21:59:00 +08:00

1001 lines
29 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"accounting-app/internal/config"
"accounting-app/internal/models"
"accounting-app/internal/repository"
"gorm.io/gorm"
)
// TranscriptionResult represents the result of audio transcription
type TranscriptionResult struct {
Text string `json:"text"`
Language string `json:"language,omitempty"`
Duration float64 `json:"duration,omitempty"`
}
// AITransactionParams represents parsed transaction parameters
type AITransactionParams struct {
Amount *float64 `json:"amount,omitempty"`
Category string `json:"category,omitempty"`
CategoryID *uint `json:"category_id,omitempty"`
Account string `json:"account,omitempty"`
AccountID *uint `json:"account_id,omitempty"`
Type string `json:"type,omitempty"` // "expense" or "income"
Date string `json:"date,omitempty"`
Note string `json:"note,omitempty"`
}
// ConfirmationCard represents a transaction confirmation card
type ConfirmationCard struct {
SessionID string `json:"session_id"`
Amount float64 `json:"amount"`
Category string `json:"category"`
CategoryID uint `json:"category_id"`
Account string `json:"account"`
AccountID uint `json:"account_id"`
Type string `json:"type"`
Date string `json:"date"`
Note string `json:"note,omitempty"`
IsComplete bool `json:"is_complete"`
}
// AIChatResponse is the reply returned to the client for one chat message.
type AIChatResponse struct {
	// SessionID is the session the client should send with its next message.
	SessionID string `json:"session_id"`
	// Message is the assistant text shown to the user.
	Message string `json:"message"`
	// Intent is one of "create_transaction", "query", "unknown".
	Intent string `json:"intent,omitempty"`
	// Params carries the transaction parameters gathered so far.
	Params *AITransactionParams `json:"params,omitempty"`
	// ConfirmationCard is set once all required parameters are present.
	ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
	// NeedsFollowUp reports whether required parameters are still missing.
	NeedsFollowUp bool `json:"needs_follow_up"`
	// FollowUpQuestion asks the user for the missing parameters.
	FollowUpQuestion string `json:"follow_up_question,omitempty"`
}
// AISession is the in-memory state of one AI bookkeeping conversation.
type AISession struct {
	// ID is the key under which the session is stored.
	ID string
	// UserID is the user the session was created for.
	UserID uint
	// Params accumulates transaction parameters across messages.
	Params *AITransactionParams
	// Messages is the conversation history in chronological order.
	Messages []ChatMessage
	// CreatedAt is the creation time of the session.
	CreatedAt time.Time
	// ExpiresAt is the time after which the session is treated as expired.
	ExpiresAt time.Time
}
// ChatMessage represents a message in the conversation
type ChatMessage struct {
Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"`
}
// WhisperService transcribes audio by calling the transcription endpoint
// configured in config.
type WhisperService struct {
	// config supplies the API key, base URL and model name.
	config *config.Config
	// httpClient performs the transcription requests.
	httpClient *http.Client
}
// NewWhisperService returns a WhisperService that reads its settings from cfg
// and uses an HTTP client with a 120-second timeout, chosen to leave room for
// audio transcription.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}

	svc := new(WhisperService)
	svc.config = cfg
	svc.httpClient = client
	return svc
}
// TranscribeAudio uploads audioData as multipart form data to
// OpenAIBaseURL + "/audio/transcriptions" and decodes the JSON reply.
//
// filename is sent as the multipart file name, and the text after its last
// dot (the whole name when there is no dot) is validated against the accepted
// formats: mp3, wav, m4a, webm, ogg, flac. The request carries the configured
// WhisperModel and a fixed "zh" language hint.
//
// An error is returned when the API key or base URL is not configured, the
// format is not accepted, the form or request cannot be built, the request
// fails, the status is not 200 (the response body is included in the error),
// or the response cannot be decoded.
//
// Requirements: 6.1-6.7
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	cfg := s.config
	if cfg.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if cfg.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// LastIndex returns -1 when there is no dot, so the slice below then
	// starts at 0 and the whole file name is treated as the extension.
	dot := strings.LastIndex(filename, ".")
	ext := strings.ToLower(filename[dot+1:])
	switch ext {
	case "mp3", "wav", "m4a", "webm", "ogg", "flac":
		// accepted
	default:
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Assemble the multipart body: the audio file first, then the text fields.
	payload := new(bytes.Buffer)
	form := multipart.NewWriter(payload)

	filePart, err := form.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err = io.Copy(filePart, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}

	textFields := [...][2]string{
		{"model", cfg.WhisperModel},
		{"language", "zh"}, // language hint for Chinese
	}
	for _, kv := range textFields {
		if err = form.WriteField(kv[0], kv[1]); err != nil {
			return nil, fmt.Errorf("failed to write %s field: %w", kv[0], err)
		}
	}
	// Close writes the terminating boundary; the body is incomplete without it.
	if err = form.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	endpoint := cfg.OpenAIBaseURL + "/audio/transcriptions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, payload)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+cfg.OpenAIAPIKey)
	req.Header.Set("Content-Type", form.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The read error is ignored on purpose: the body only enriches the message.
		raw, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(raw))
	}

	result := new(TranscriptionResult)
	if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return result, nil
}
// LLMService turns free-form text into transaction parameters and resolves
// account and category names to stored records.
type LLMService struct {
	// config supplies the API key, base URL and chat model name.
	config *config.Config
	// httpClient performs the chat-completion requests.
	httpClient *http.Client
	// accountRepo is used to look up a user's accounts by name.
	accountRepo *repository.AccountRepository
	// categoryRepo is used to look up a user's categories by name.
	categoryRepo *repository.CategoryRepository
}
// NewLLMService returns an LLMService wired to cfg and the given repositories.
// Its HTTP client uses a 60-second timeout to tolerate slow API responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	client := &http.Client{Timeout: 60 * time.Second}

	svc := new(LLMService)
	svc.config = cfg
	svc.httpClient = client
	svc.accountRepo = accountRepo
	svc.categoryRepo = categoryRepo
	return svc
}
// ChatCompletionRequest is the JSON body sent to the chat-completions endpoint.
type ChatCompletionRequest struct {
	// Model is the name of the chat model to use.
	Model string `json:"model"`
	// Messages is the conversation, system prompt first.
	Messages []ChatMessage `json:"messages"`
	// Functions lists callable function definitions; omitted from JSON when empty.
	Functions []Function `json:"functions,omitempty"`
	// Temperature is the sampling temperature; always serialized, including 0.
	Temperature float64 `json:"temperature"`
}
// Function represents an OpenAI function definition
type Function struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
// ChatCompletionResponse represents OpenAI chat completion response
type ChatCompletionResponse struct {
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
FunctionCall *struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function_call,omitempty"`
} `json:"message"`
} `json:"choices"`
}
// ParseIntent extracts transaction parameters from a free-form user message.
//
// When both OpenAIAPIKey and OpenAIBaseURL are configured, text is sent to
// OpenAIBaseURL + "/chat/completions" together with a system prompt and at
// most the four most recent messages of history. The reply is expected to be
// a JSON object, optionally wrapped in a Markdown code fence. When either
// setting is empty, parseIntentSimple is used instead.
//
// Returns:
//   - the parsed parameters, or nil when the model's reply is not valid JSON;
//   - a message for the user: the "message" field of the JSON reply, or the
//     raw reply text when it is not JSON;
//   - an error when the request cannot be built or sent, the status is not
//     200, the response body cannot be decoded, or it contains no choices.
//
// Requirements: 7.1, 7.5, 7.6
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
	// Number of prior messages forwarded to the model (two user/assistant
	// exchanges); older history is dropped to keep the context small.
	const maxHistory = 4

	// TODO: the regex-based fast path (parseIntentSimple for simple inputs
	// such as "6块钱奶茶") is temporarily disabled; whenever the API is
	// configured the LLM is always used.
	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No API configuration: fall back to local parsing. The error result
		// is discarded because parseIntentSimple always returns nil.
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	// The prompt contains exactly four %s verbs, all filled with today's date.
	todayDate := time.Now().Format("2006-01-02")
	systemPrompt := fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。
今天的日期是：%s
规则：
1. 金额：提取数字，如"6元"=6，"十五元"=15
2. 分类：根据内容推断，如"奶茶/咖啡/吃饭"=餐饮，"打车/地铁"=交通，"买衣服"=购物
3. 类型：默认expense（支出），除非明确说"收入/工资/奖金/红包"
4. 日期：默认使用今天的日期（%s），除非用户明确指定其他日期
5. 备注：提取关键描述
直接返回JSON，不要解释：
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}
示例（假设今天是%s）：
用户："买了6块的奶茶"
返回：{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%s","message":"记录餐饮支出6元，奶茶"}`, todayDate, todayDate, todayDate, todayDate)

	if len(history) > maxHistory {
		history = history[len(history)-maxHistory:]
	}
	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // low temperature for more consistent output
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The read error is ignored on purpose: the body only enriches the message.
		body, _ := io.ReadAll(resp.Body)
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	// Strip a Markdown code fence ("```json ... ```") when the model adds one.
	content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
	if strings.HasPrefix(content, "```") {
		// Drop the opening fence line ("```json" or "```").
		if idx := strings.Index(content, "\n"); idx != -1 {
			content = content[idx+1:]
		}
		// Drop the closing fence and anything after it.
		if idx := strings.LastIndex(content, "```"); idx != -1 {
			content = content[:idx]
		}
		content = strings.TrimSpace(content)
	}

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Not JSON: hand the raw text back as the assistant message.
		return nil, content, nil
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}
// Regular expressions used by parseIntentSimple, compiled once at package
// initialization instead of on every call.
var (
	// simpleAmountPatterns are tried in order; the first one whose capture
	// group parses as a float supplies the amount. The last entry is the
	// catch-all that accepts any bare number.
	simpleAmountPatterns = []*regexp.Regexp{
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:块钱|元钱|元|块|¥|￥)`),
		regexp.MustCompile(`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`),
		regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(?:的|块的)`),
		regexp.MustCompile(`(\d+(?:\.\d+)?)`),
	}

	// simpleNoteAmountPattern removes numbers and an optional currency unit
	// from the note. Go's regexp picks the first alternative that matches, so
	// the two-character units must come before their one-character prefixes;
	// otherwise "6块钱奶茶" would leave a stray "钱" in the note.
	simpleNoteAmountPattern = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:块钱|元钱|元|块|¥|￥)?`)
)

// parseIntentSimple extracts transaction parameters from text with regular
// expressions and keyword tables, without calling any external service. It is
// the fallback used by ParseIntent when no API configuration is available.
//
// The result always has Type ("expense" unless an income keyword is found),
// Date (today) and Category ("其他" when no keyword matches). Amount stays
// nil when the text contains no number, and Note is the text with numbers,
// currency units and filler words removed.
//
// The returned message is a confirmation summary, or a question asking for
// the amount when none was found. The error result is always nil.
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
	params := &AITransactionParams{
		Type: "expense", // default to expense
		Date: time.Now().Format("2006-01-02"),
	}

	// Amount: first pattern that yields a parsable number wins.
	for _, re := range simpleAmountPatterns {
		matches := re.FindStringSubmatch(text)
		if len(matches) < 2 {
			continue
		}
		if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
			params.Amount = &amount
			break
		}
	}

	// Category: groups are checked in order, so earlier groups take priority.
	categoryPatterns := []struct {
		keywords []string
		category string
	}{
		{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
		{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
		{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
		{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
		{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
		{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
		{[]string{"买", "购"}, "购物"},
		{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
		{[]string{"房租", "租金", "房贷"}, "住房"},
		{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
		{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
		{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
		{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
		{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
		{[]string{"红包", "转账", "收款"}, "其他收入"},
	}
categorySearch:
	for _, cp := range categoryPatterns {
		for _, keyword := range cp.keywords {
			if strings.Contains(text, keyword) {
				params.Category = cp.category
				break categorySearch
			}
		}
	}
	if params.Category == "" {
		params.Category = "其他"
	}

	// Type: any income keyword switches the transaction to income.
	incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
	for _, keyword := range incomeKeywords {
		if strings.Contains(text, keyword) {
			params.Type = "income"
			break
		}
	}

	// Note: the input minus numbers, currency units and filler words.
	note := text
	if params.Amount != nil {
		note = simpleNoteAmountPattern.ReplaceAllString(note, "")
	}
	note = strings.TrimSpace(note)
	fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
	for _, word := range fillerWords {
		note = strings.ReplaceAll(note, word, "")
	}
	note = strings.TrimSpace(note)
	if note != "" {
		params.Note = note
	}

	// Response message.
	if params.Amount == nil {
		return params, "请问金额是多少？", nil
	}
	typeLabel := "支出"
	if params.Type == "income" {
		typeLabel = "收入"
	}
	message := fmt.Sprintf("记录：%s %.2f元，分类：%s", typeLabel, *params.Amount, params.Category)
	if params.Note != "" {
		message += "，备注：" + params.Note
	}
	return params, message, nil
}
// MapAccountName resolves a natural-language account name to one of the
// user's accounts.
//
// Matching is done in two passes over the accounts returned by the
// repository: first a case-insensitive exact match, then a case-insensitive
// substring match in either direction. The first hit wins.
//
// It returns the account ID and stored name on a match, (nil, "", nil) when
// name is empty or nothing matches, and the repository error when the lookup
// fails. ctx is currently unused.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	// Pass 1: exact match, ignoring case.
	for _, candidate := range accounts {
		if !strings.EqualFold(candidate.Name, name) {
			continue
		}
		id := candidate.ID
		return &id, candidate.Name, nil
	}

	// Pass 2: either name contains the other, ignoring case.
	needle := strings.ToLower(name)
	for _, candidate := range accounts {
		stored := strings.ToLower(candidate.Name)
		if strings.Contains(stored, needle) || strings.Contains(needle, stored) {
			id := candidate.ID
			return &id, candidate.Name, nil
		}
	}

	return nil, "", nil
}
// MapCategoryName resolves a natural-language category name to one of the
// user's categories.
//
// Candidates are limited by txType: "expense" and "income" keep only
// categories of that type, an empty txType keeps all categories, and any
// other value keeps none. Matching then runs in two passes: a
// case-insensitive exact match, followed by a case-insensitive substring
// match in either direction. The first hit wins.
//
// It returns the category ID and stored name on a match, (nil, "", nil) when
// name is empty or nothing matches, and the repository error when the lookup
// fails. ctx is currently unused.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	// Narrow the candidates to the requested transaction type.
	var candidates []models.Category
	switch txType {
	case "":
		candidates = append(candidates, categories...)
	case "expense", "income":
		for _, c := range categories {
			if string(c.Type) == txType {
				candidates = append(candidates, c)
			}
		}
	}

	// Pass 1: exact match, ignoring case.
	for _, c := range candidates {
		if !strings.EqualFold(c.Name, name) {
			continue
		}
		id := c.ID
		return &id, c.Name, nil
	}

	// Pass 2: either name contains the other, ignoring case.
	needle := strings.ToLower(name)
	for _, c := range candidates {
		stored := strings.ToLower(c.Name)
		if strings.Contains(stored, needle) || strings.Contains(needle, stored) {
			id := c.ID
			return &id, c.Name, nil
		}
	}

	return nil, "", nil
}
// AIBookkeepingService ties together transcription, intent parsing, session
// tracking and transaction creation for the AI bookkeeping feature.
type AIBookkeepingService struct {
	// whisperService transcribes audio to text.
	whisperService *WhisperService
	// llmService parses intent and resolves account and category names.
	llmService *LLMService
	// transactionRepo persists confirmed transactions.
	transactionRepo *repository.TransactionRepository
	// accountRepo reads accounts and stores balance updates.
	accountRepo *repository.AccountRepository
	// categoryRepo reads the user's categories.
	categoryRepo *repository.CategoryRepository
	// userSettingsRepo provides the user's default accounts.
	userSettingsRepo *repository.UserSettingsRepository
	// db is the database handle passed to the constructor.
	db *gorm.DB
	// sessions holds the active conversations keyed by session ID.
	sessions map[string]*AISession
	// sessionMutex guards sessions.
	sessionMutex sync.RWMutex
	// config supplies settings such as AISessionTimeout.
	config *config.Config
}
// NewAIBookkeepingService builds an AIBookkeepingService together with its
// WhisperService and LLMService, and starts the background goroutine that
// removes expired sessions. That goroutine has no stop signal.
func NewAIBookkeepingService(
	cfg *config.Config,
	transactionRepo *repository.TransactionRepository,
	accountRepo *repository.AccountRepository,
	categoryRepo *repository.CategoryRepository,
	userSettingsRepo *repository.UserSettingsRepository,
	db *gorm.DB,
) *AIBookkeepingService {
	service := &AIBookkeepingService{
		whisperService:   NewWhisperService(cfg),
		llmService:       NewLLMService(cfg, accountRepo, categoryRepo),
		transactionRepo:  transactionRepo,
		accountRepo:      accountRepo,
		categoryRepo:     categoryRepo,
		userSettingsRepo: userSettingsRepo,
		db:               db,
		sessions:         map[string]*AISession{},
		config:           cfg,
	}

	// Background sweep of expired sessions.
	go service.cleanupExpiredSessions()

	return service
}
// sessionIDSeq is the process-wide sequence number appended to session IDs.
var sessionIDSeq struct {
	mu   sync.Mutex
	next uint64
}

// generateSessionID returns a session ID of the form "ai_<unix-nanos>_<seq>".
//
// seq is a process-wide counter, so two IDs generated within the same clock
// tick still differ; a suffix derived from the clock could not guarantee
// that. The ID is unique within the process but not secret: callers must not
// treat knowledge of an ID as proof of ownership.
func generateSessionID() string {
	sessionIDSeq.mu.Lock()
	sessionIDSeq.next++
	seq := sessionIDSeq.next
	sessionIDSeq.mu.Unlock()

	return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), seq)
}
// getOrCreateSession returns the live session stored under sessionID when it
// belongs to userID, and otherwise creates and stores a new session for
// userID.
//
// An expired session found under sessionID is removed. A live session owned
// by a different user is left untouched and is never returned, so a session
// ID cannot be used to read or modify another user's conversation. New
// sessions expire AISessionTimeout after creation.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if existing, ok := s.sessions[sessionID]; ok {
			switch {
			case !now.Before(existing.ExpiresAt):
				// Expired: drop it and fall through to create a fresh one.
				delete(s.sessions, sessionID)
			case existing.UserID == userID:
				return existing
			}
			// Otherwise the session is live but owned by someone else:
			// ignore it and give the caller a session of their own.
		}
	}

	session := &AISession{
		ID:        generateSessionID(),
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[session.ID] = session
	return session
}
// cleanupExpiredSessions removes every session whose ExpiresAt has passed,
// once every five minutes. It loops forever and is meant to run in its own
// goroutine.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	const sweepInterval = 5 * time.Minute

	ticker := time.NewTicker(sweepInterval)
	for {
		<-ticker.C

		s.sessionMutex.Lock()
		cutoff := time.Now()
		for id := range s.sessions {
			if s.sessions[id].ExpiresAt.Before(cutoff) {
				delete(s.sessions, id)
			}
		}
		s.sessionMutex.Unlock()
	}
}
// ProcessChat handles one user message within a bookkeeping conversation.
//
// It loads or creates the session, asks the LLM service to parse the message,
// merges the parsed parameters into the session, resolves account and
// category names to IDs (falling back to defaults), and then either asks a
// follow-up question for missing fields or returns a confirmation card. Both
// the user message and the assistant reply are appended to the session
// history.
//
// The returned response always carries the session ID to use for the next
// message. An error is returned only when intent parsing fails.
//
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "user",
		Content: message,
	})

	// The history handed to the parser excludes the message just appended,
	// because the parser adds the current message itself.
	params, responseMsg, err := s.llmService.ParseIntent(ctx, message, session.Messages[:len(session.Messages)-1])
	if err != nil {
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}

	if params != nil {
		s.mergeParams(session.Params, params)
	}
	current := session.Params

	// Resolve names to IDs. Lookup errors are deliberately ignored: resolution
	// is best-effort and the default fallbacks below cover a failed lookup.
	if current.Account != "" && current.AccountID == nil {
		accountID, accountName, _ := s.llmService.MapAccountName(ctx, current.Account, userID)
		if accountID != nil {
			current.AccountID = accountID
			current.Account = accountName
		}
	}
	if current.Category != "" && current.CategoryID == nil {
		categoryID, categoryName, _ := s.llmService.MapCategoryName(ctx, current.Category, current.Type, userID)
		if categoryID != nil {
			current.CategoryID = categoryID
			current.Category = categoryName
		}
	}

	// A category was named but could not be resolved: attach the default
	// category's ID while keeping the name supplied by the parser.
	if current.CategoryID == nil && current.Category != "" {
		if defaultCategoryID, _ := s.getDefaultCategory(userID, current.Type); defaultCategoryID != nil {
			current.CategoryID = defaultCategoryID
		}
	}

	// No resolved account: use the default account for this transaction type.
	if current.AccountID == nil {
		defaultAccountID, defaultAccountName := s.getDefaultAccount(userID, current.Type)
		if defaultAccountID != nil {
			current.AccountID = defaultAccountID
			current.Account = defaultAccountName
		}
	}

	response := &AIChatResponse{
		SessionID: session.ID,
		Message:   responseMsg,
		Intent:    "create_transaction",
		Params:    current,
	}

	if missingFields := s.getMissingFields(current); len(missingFields) > 0 {
		response.NeedsFollowUp = true
		response.FollowUpQuestion = s.generateFollowUpQuestion(missingFields)
		if responseMsg == "" {
			response.Message = response.FollowUpQuestion
		}
	} else {
		// Everything required is present; Amount is non-nil here because
		// getMissingFields reports "amount" whenever it is nil.
		response.ConfirmationCard = s.GenerateConfirmationCard(session)
		response.Message = fmt.Sprintf("请确认：%s %.2f元，分类：%s，账户：%s",
			s.getTypeLabel(current.Type),
			*current.Amount,
			current.Category,
			current.Account)
	}

	session.Messages = append(session.Messages, ChatMessage{
		Role:    "assistant",
		Content: response.Message,
	})
	return response, nil
}
// mergeParams copies every non-empty field of new into existing, leaving
// fields that new does not set untouched.
//
// A resolved ID is only valid for the name it was resolved from, and
// ProcessChat resolves a name only while its ID is nil. Therefore:
//   - when the category name or the transaction type changes, CategoryID is
//     cleared (categories are looked up per transaction type);
//   - when the account name changes, AccountID is cleared.
//
// An ID supplied explicitly by new is applied afterwards and always wins.
func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) {
	if new.Amount != nil {
		existing.Amount = new.Amount
	}

	if new.Type != "" {
		if new.Type != existing.Type {
			existing.CategoryID = nil
		}
		existing.Type = new.Type
	}

	if new.Category != "" {
		if new.Category != existing.Category {
			existing.CategoryID = nil
		}
		existing.Category = new.Category
	}
	if new.CategoryID != nil {
		existing.CategoryID = new.CategoryID
	}

	if new.Account != "" {
		if new.Account != existing.Account {
			existing.AccountID = nil
		}
		existing.Account = new.Account
	}
	if new.AccountID != nil {
		existing.AccountID = new.AccountID
	}

	if new.Date != "" {
		existing.Date = new.Date
	}
	if new.Note != "" {
		existing.Note = new.Note
	}
}
// getDefaultAccount picks the account to use when the user did not name one.
//
// It prefers the default account configured in the user's settings for
// txType ("expense" or "income") when that account can be loaded. Otherwise
// it falls back to the first account returned by the repository. It returns
// (nil, "") when the user has no accounts or the lookup fails.
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
	// Configured default, if any.
	if settings, err := s.userSettingsRepo.GetOrCreate(userID); err == nil && settings != nil {
		var preferred *uint
		switch txType {
		case "expense":
			preferred = settings.DefaultExpenseAccountID
		case "income":
			preferred = settings.DefaultIncomeAccountID
		}
		if preferred != nil {
			if found, err := s.accountRepo.GetByID(userID, *preferred); err == nil && found != nil {
				return preferred, found.Name
			}
		}
	}

	// Fallback: first available account (usually sorted by sort_order).
	all, err := s.accountRepo.GetAll(userID)
	if err != nil || len(all) == 0 {
		return nil, ""
	}
	return &all[0].ID, all[0].Name
}
// getDefaultCategory returns the first of the user's categories whose type
// matches txType; any txType other than "income" is treated as "expense".
// When no category has the wanted type, the first category of any type is
// returned. It returns (nil, "") when the user has no categories or the
// lookup fails.
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil || len(categories) == 0 {
		return nil, ""
	}

	wanted := "expense"
	if txType == "income" {
		wanted = "income"
	}

	for i := range categories {
		if string(categories[i].Type) != wanted {
			continue
		}
		// Copy the ID so the returned pointer does not reference the slice.
		id := categories[i].ID
		return &id, categories[i].Name
	}

	// Nothing of the wanted type: fall back to the very first category.
	return &categories[0].ID, categories[0].Name
}
// getMissingFields lists the required fields that params does not yet
// provide, in the order "amount", "category", "account". A category or
// account counts as present when either its ID or its name is set. The
// result is nil when nothing is missing.
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
	requirements := []struct {
		field  string
		absent bool
	}{
		{"amount", params.Amount == nil},
		{"category", params.CategoryID == nil && params.Category == ""},
		{"account", params.AccountID == nil && params.Account == ""},
	}

	var missing []string
	for _, r := range requirements {
		if r.absent {
			missing = append(missing, r.field)
		}
	}
	return missing
}
// generateFollowUpQuestion builds the question that asks the user for the
// missing fields. Known field keys ("amount", "category", "account") are
// translated to their display labels and unknown keys are skipped.
//
// It returns "" when missing is empty, a single-field question when exactly
// one label results, and otherwise a request listing the labels joined by
// "、".
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
	if len(missing) == 0 {
		return ""
	}

	labels := make([]string, 0, len(missing))
	for _, field := range missing {
		switch field {
		case "amount":
			labels = append(labels, "金额")
		case "category":
			labels = append(labels, "分类")
		case "account":
			labels = append(labels, "账户")
		}
	}

	if len(labels) == 1 {
		return fmt.Sprintf("请问%s是多少", labels[0])
	}
	return fmt.Sprintf("请补充以下信息：%s", strings.Join(labels, "、"))
}
// getTypeLabel returns the Chinese display label for a transaction type:
// "收入" for "income" and "支出" for everything else.
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
	switch txType {
	case "income":
		return "收入"
	default:
		return "支出"
	}
}
// GenerateConfirmationCard builds the confirmation card for the parameters
// collected in session.
//
// Amount, CategoryID and AccountID stay zero when the corresponding parameter
// is nil. Date defaults to today (YYYY-MM-DD) when the session has none.
// IsComplete is always set to true.
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
	p := session.Params

	date := p.Date
	if date == "" {
		date = time.Now().Format("2006-01-02")
	}

	card := &ConfirmationCard{
		SessionID:  session.ID,
		Category:   p.Category,
		Account:    p.Account,
		Type:       p.Type,
		Date:       date,
		Note:       p.Note,
		IsComplete: true,
	}

	// Optional pointer fields are copied only when present.
	if p.Amount != nil {
		card.Amount = *p.Amount
	}
	if p.CategoryID != nil {
		card.CategoryID = *p.CategoryID
	}
	if p.AccountID != nil {
		card.AccountID = *p.AccountID
	}
	return card
}
// TranscribeAudio forwards the audio stream and file name to the underlying
// WhisperService and returns its result unchanged.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	result, err := s.whisperService.TranscribeAudio(ctx, audioData, filename)
	return result, err
}
// ConfirmTransaction creates a transaction from the parameters stored in the
// session, adjusts the account balance, and removes the session.
//
// The session must exist, belong to userID and not be expired; all three
// failures produce the same error so a caller cannot probe for other users'
// sessions. The amount must be positive and both category and account must be
// resolved to IDs. The account is loaded before anything is written, so an
// unknown account does not leave a transaction behind.
//
// A session date that is empty or not in YYYY-MM-DD form falls back to the
// current time. ctx is currently unused.
//
// NOTE(review): creating the transaction and updating the balance are two
// separate repository calls, not one database transaction; if the balance
// update fails the transaction row remains. Wrapping both would need a
// transactional repository API that is not visible from this file.
//
// Requirements: 7.10
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
	s.sessionMutex.RLock()
	session, ok := s.sessions[sessionID]
	s.sessionMutex.RUnlock()
	if !ok || session.UserID != userID || time.Now().After(session.ExpiresAt) {
		return nil, errors.New("session not found or expired")
	}
	params := session.Params

	// Validate required fields.
	if params.Amount == nil || *params.Amount <= 0 {
		return nil, errors.New("invalid amount")
	}
	if params.CategoryID == nil {
		return nil, errors.New("category not specified")
	}
	if params.AccountID == nil {
		return nil, errors.New("account not specified")
	}

	// Best-effort date parsing: an empty or malformed date means "now".
	txDate := time.Now()
	if params.Date != "" {
		if parsed, err := time.Parse("2006-01-02", params.Date); err == nil {
			txDate = parsed
		}
	}

	txType := models.TransactionTypeExpense
	if params.Type == "income" {
		txType = models.TransactionTypeIncome
	}

	// Load the account first so a missing or foreign account is rejected
	// before the transaction is written.
	account, err := s.accountRepo.GetByID(userID, *params.AccountID)
	if err != nil {
		return nil, fmt.Errorf("failed to find account: %w", err)
	}
	if account == nil {
		return nil, errors.New("failed to find account: account not found")
	}

	tx := &models.Transaction{
		UserID:          userID,
		Amount:          *params.Amount,
		Type:            txType,
		CategoryID:      *params.CategoryID,
		AccountID:       *params.AccountID,
		TransactionDate: txDate,
		Note:            params.Note,
		Currency:        "CNY",
	}
	if err := s.transactionRepo.Create(tx); err != nil {
		return nil, fmt.Errorf("failed to create transaction: %w", err)
	}

	// Apply the amount to the account balance.
	if txType == models.TransactionTypeExpense {
		account.Balance -= *params.Amount
	} else {
		account.Balance += *params.Amount
	}
	if err := s.accountRepo.Update(account); err != nil {
		return nil, fmt.Errorf("failed to update account balance: %w", err)
	}

	// The session has served its purpose.
	s.sessionMutex.Lock()
	delete(s.sessions, sessionID)
	s.sessionMutex.Unlock()

	return tx, nil
}
// GetSession looks up a session by ID. It returns (nil, false) when no
// session is stored under sessionID or when the stored session's ExpiresAt
// has passed. The session's owner is not checked.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if found, exists := s.sessions[sessionID]; exists && !time.Now().After(found.ExpiresAt) {
		return found, true
	}
	return nil, false
}