1001 lines
29 KiB
Go
1001 lines
29 KiB
Go
package service
|
||
|
||
import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"accounting-app/internal/config"
	"accounting-app/internal/models"
	"accounting-app/internal/repository"

	"gorm.io/gorm"
)
|
||
|
||
// TranscriptionResult is the decoded JSON body of a successful audio
// transcription response (see WhisperService.TranscribeAudio).
type TranscriptionResult struct {
	Text     string  `json:"text"`               // transcribed text
	Language string  `json:"language,omitempty"` // language reported by the service, if any
	Duration float64 `json:"duration,omitempty"` // NOTE(review): presumably seconds of audio — confirm against the API
}
|
||
|
||
// AITransactionParams holds the transaction fields extracted from the user's
// messages and accumulated across one chat session.
//
// A nil pointer or an empty string means the value has not been provided.
type AITransactionParams struct {
	Amount     *float64 `json:"amount,omitempty"`
	Category   string   `json:"category,omitempty"`    // category name (as extracted, or as resolved from the user's categories)
	CategoryID *uint    `json:"category_id,omitempty"` // resolved category ID; required by ConfirmTransaction
	Account    string   `json:"account,omitempty"`     // account name
	AccountID  *uint    `json:"account_id,omitempty"`  // resolved account ID; required by ConfirmTransaction
	Type       string   `json:"type,omitempty"`        // "expense" or "income"
	Date       string   `json:"date,omitempty"`        // YYYY-MM-DD (parsed with layout "2006-01-02")
	Note       string   `json:"note,omitempty"`        // free-text note
}
|
||
|
||
// ConfirmationCard summarizes a fully specified transaction so the user can
// confirm it. GenerateConfirmationCard fills it from the session params:
// unresolved IDs and a missing amount appear as 0, an empty date becomes
// today, and IsComplete is always true.
type ConfirmationCard struct {
	SessionID  string  `json:"session_id"`
	Amount     float64 `json:"amount"`
	Category   string  `json:"category"`
	CategoryID uint    `json:"category_id"`
	Account    string  `json:"account"`
	AccountID  uint    `json:"account_id"`
	Type       string  `json:"type"` // "expense" or "income"
	Date       string  `json:"date"` // YYYY-MM-DD
	Note       string  `json:"note,omitempty"`
	IsComplete bool    `json:"is_complete"`
}
|
||
|
||
// AIChatResponse is the result of one ProcessChat turn.
type AIChatResponse struct {
	SessionID string `json:"session_id"` // ID of the session this turn used (may be newly created)
	Message   string `json:"message"`    // text to show the user
	// Intent is one of "create_transaction", "query", "unknown".
	// ProcessChat currently always sets "create_transaction".
	Intent string `json:"intent,omitempty"`
	// Params is the session's accumulated transaction params.
	Params *AITransactionParams `json:"params,omitempty"`
	// ConfirmationCard is set only when no required field is missing.
	ConfirmationCard *ConfirmationCard `json:"confirmation_card,omitempty"`
	// NeedsFollowUp reports that required fields are missing; the question
	// asking for them is in FollowUpQuestion.
	NeedsFollowUp    bool   `json:"needs_follow_up"`
	FollowUpQuestion string `json:"follow_up_question,omitempty"`
}
|
||
|
||
// AISession is the in-memory state of one AI bookkeeping conversation. Sessions
// are kept in AIBookkeepingService.sessions, keyed by ID.
type AISession struct {
	ID     string
	UserID uint // user the session was created for
	// Params accumulates the transaction fields extracted so far.
	Params *AITransactionParams
	// Messages is the conversation history (user and assistant turns).
	Messages  []ChatMessage
	CreatedAt time.Time
	// ExpiresAt is set at creation to the creation time plus
	// config.AISessionTimeout.
	ExpiresAt time.Time
}
|
||
|
||
// ChatMessage is one conversation turn. The same type is stored in
// AISession.Messages and sent in ChatCompletionRequest.Messages.
type ChatMessage struct {
	Role    string `json:"role"` // "user", "assistant", "system"
	Content string `json:"content"`
}
|
||
|
||
// WhisperService transcribes audio through an OpenAI-compatible
// "/audio/transcriptions" endpoint.
type WhisperService struct {
	config     *config.Config // supplies OpenAIAPIKey, OpenAIBaseURL and WhisperModel
	httpClient *http.Client   // NewWhisperService creates it with a 120-second timeout
}
|
||
|
||
// NewWhisperService returns a WhisperService that reads its endpoint,
// credentials and model from cfg. Its HTTP client allows 120 seconds per
// request because uploading and transcribing audio can be slow.
func NewWhisperService(cfg *config.Config) *WhisperService {
	client := &http.Client{Timeout: 120 * time.Second}
	return &WhisperService{config: cfg, httpClient: client}
}
|
||
|
||
// TranscribeAudio uploads an audio clip to the OpenAI-compatible
// "/audio/transcriptions" endpoint and returns the decoded transcription.
//
// filename is sent as the multipart file name. Its extension (the text after
// the last dot, case-insensitive) must be one of mp3, wav, m4a, webm, ogg or
// flac. The request uses the configured WhisperModel and always sends the
// language hint "zh". audioData is read fully into memory before the request
// is sent.
//
// An error is returned in these cases:
//   - the API key or base URL is not configured;
//   - the extension is missing or unsupported;
//   - the request cannot be built or sent;
//   - the server answers with a non-200 status (at most the first 4 KiB of
//     the body is included in the error);
//   - the response is not valid JSON.
//
// Requirements: 6.1-6.7
func (s *WhisperService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	// Upper bound on how much of an error response is copied into the error.
	const maxErrorBodyBytes = 4 << 10

	if s.config.OpenAIAPIKey == "" {
		return nil, errors.New("OpenAI API key not configured (OPENAI_API_KEY)")
	}
	if s.config.OpenAIBaseURL == "" {
		return nil, errors.New("OpenAI base URL not configured (OPENAI_BASE_URL)")
	}

	// A name without any dot has no extension; do not treat the whole name
	// as one.
	dot := strings.LastIndex(filename, ".")
	if dot < 0 {
		return nil, fmt.Errorf("unsupported audio format: %q has no file extension", filename)
	}
	ext := strings.ToLower(filename[dot+1:])
	switch ext {
	case "mp3", "wav", "m4a", "webm", "ogg", "flac":
		// supported
	default:
		return nil, fmt.Errorf("unsupported audio format: %s", ext)
	}

	// Build the multipart form: the audio file plus model and language fields.
	var buf bytes.Buffer
	writer := multipart.NewWriter(&buf)

	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("failed to create form file: %w", err)
	}
	if _, err := io.Copy(part, audioData); err != nil {
		return nil, fmt.Errorf("failed to copy audio data: %w", err)
	}
	if err := writer.WriteField("model", s.config.WhisperModel); err != nil {
		return nil, fmt.Errorf("failed to write model field: %w", err)
	}
	// The app's users speak Chinese; the hint improves recognition.
	if err := writer.WriteField("language", "zh"); err != nil {
		return nil, fmt.Errorf("failed to write language field: %w", err)
	}
	// Close writes the terminating boundary; it must happen before sending.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("failed to close writer: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.OpenAIBaseURL+"/audio/transcriptions", &buf)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("transcription request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort: a read error here just yields a shorter message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("transcription failed with status %d: %s", resp.StatusCode, string(body))
	}

	var result TranscriptionResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	return &result, nil
}
|
||
|
||
// LLMService extracts transaction parameters from text. It uses an
// OpenAI-compatible "/chat/completions" endpoint when configured and a local
// regex/keyword parser otherwise. It also maps names to the user's accounts
// and categories.
type LLMService struct {
	config       *config.Config // supplies OpenAIAPIKey, OpenAIBaseURL and ChatModel
	httpClient   *http.Client   // NewLLMService creates it with a 60-second timeout
	accountRepo  *repository.AccountRepository
	categoryRepo *repository.CategoryRepository
}
|
||
|
||
// NewLLMService returns an LLMService that reads endpoint settings from cfg
// and resolves names with the given repositories. Its HTTP client allows 60
// seconds per request to tolerate slow model responses.
func NewLLMService(cfg *config.Config, accountRepo *repository.AccountRepository, categoryRepo *repository.CategoryRepository) *LLMService {
	svc := new(LLMService)
	svc.config = cfg
	svc.httpClient = &http.Client{Timeout: 60 * time.Second}
	svc.accountRepo = accountRepo
	svc.categoryRepo = categoryRepo
	return svc
}
|
||
|
||
// ChatCompletionRequest is the JSON body POSTed to "/chat/completions".
type ChatCompletionRequest struct {
	Model    string        `json:"model"`
	Messages []ChatMessage `json:"messages"`
	// Functions is omitted from the JSON when empty; ParseIntent never sets it.
	Functions []Function `json:"functions,omitempty"`
	// Temperature has no omitempty, so it is always serialized, even when 0.
	Temperature float64 `json:"temperature"`
}
|
||
|
||
// Function is one entry of ChatCompletionRequest.Functions.
type Function struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	// Parameters is serialized as-is.
	// NOTE(review): presumably a JSON Schema object as the OpenAI API
	// expects — confirm.
	Parameters map[string]interface{} `json:"parameters"`
}
|
||
|
||
// ChatCompletionResponse is the subset of a chat-completions response that
// this file decodes. ParseIntent reads only Choices[0].Message.Content.
type ChatCompletionResponse struct {
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
			// FunctionCall is decoded when present but not read by ParseIntent.
			FunctionCall *struct {
				Name      string `json:"name"`
				Arguments string `json:"arguments"`
			} `json:"function_call,omitempty"`
		} `json:"message"`
	} `json:"choices"`
}
|
||
|
||
// ParseIntent extracts transaction parameters from free-form text.
//
// Without an OpenAI API key or base URL it falls back to the local parser
// (parseIntentSimple). Otherwise it sends three things to the
// chat-completions endpoint and decodes the model's JSON reply:
//   - a system prompt with today's date;
//   - at most the last four history messages;
//   - text as the final user message.
//
// It returns the extracted params and a short message for the user. If the
// reply (after stripping a Markdown code fence) is not valid JSON, params is
// nil and the reply itself is returned as the message. Transport errors,
// non-200 statuses, undecodable responses and responses without choices are
// returned as errors.
//
// Requirements: 7.1, 7.5, 7.6
func (s *LLMService) ParseIntent(ctx context.Context, text string, history []ChatMessage) (*AITransactionParams, string, error) {
	const (
		// Only the most recent turns are sent, to keep the prompt small.
		maxHistoryMessages = 4
		// Upper bound on how much of an error response is copied into the error.
		maxErrorBodyBytes = 4 << 10
	)

	// The local fast path is temporarily disabled so the LLM is always used
	// when configured. To re-enable it:
	//
	//	if p, msg, _ := s.parseIntentSimple(text); p != nil && p.Amount != nil && p.Category != "" && p.Category != "其他" {
	//		return p, msg, nil
	//	}

	if s.config.OpenAIAPIKey == "" || s.config.OpenAIBaseURL == "" {
		// No LLM configured: use the local parser (its error result is always nil).
		simpleParams, simpleMsg, _ := s.parseIntentSimple(text)
		return simpleParams, simpleMsg, nil
	}

	if len(history) > maxHistoryMessages {
		history = history[len(history)-maxHistoryMessages:]
	}

	messages := make([]ChatMessage, 0, len(history)+2)
	messages = append(messages, ChatMessage{
		Role:    "system",
		Content: buildIntentSystemPrompt(time.Now().Format("2006-01-02")),
	})
	messages = append(messages, history...)
	messages = append(messages, ChatMessage{Role: "user", Content: text})

	reqBody := ChatCompletionRequest{
		Model:       s.config.ChatModel,
		Messages:    messages,
		Temperature: 0.1, // low temperature for consistent, parseable JSON
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, "", fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.config.OpenAIBaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.OpenAIAPIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("chat request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort: a read error here just yields a shorter message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, "", fmt.Errorf("chat failed with status %d: %s", resp.StatusCode, string(body))
	}

	var chatResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, "", fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, "", errors.New("no response from AI")
	}

	content := stripMarkdownCodeFence(chatResp.Choices[0].Message.Content)

	var parsed struct {
		Amount   *float64 `json:"amount"`
		Category string   `json:"category"`
		Type     string   `json:"type"`
		Note     string   `json:"note"`
		Date     string   `json:"date"`
		Message  string   `json:"message"`
	}
	if err := json.Unmarshal([]byte(content), &parsed); err != nil {
		// Not JSON (e.g. the model asked a question): pass the reply through.
		return nil, content, nil
	}

	params := &AITransactionParams{
		Amount:   parsed.Amount,
		Category: parsed.Category,
		Type:     parsed.Type,
		Note:     parsed.Note,
		Date:     parsed.Date,
	}
	return params, parsed.Message, nil
}

// buildIntentSystemPrompt returns the system prompt used by ParseIntent, with
// today (YYYY-MM-DD) substituted everywhere the prompt mentions the current
// date. Every placeholder uses the explicit argument index %[1]s, so the
// number of verbs can never disagree with the single argument.
func buildIntentSystemPrompt(today string) string {
	return fmt.Sprintf(`你是一个智能记账助手。从用户描述中提取记账信息。

今天的日期是:%[1]s

规则:
1. 金额:提取数字,如"6元"=6,"十五块"=15
2. 分类:根据内容推断,如"奶茶/咖啡/吃饭"=餐饮,"打车/地铁"=交通,"买衣服"=购物
3. 类型:默认expense(支出),除非明确说"收入/工资/奖金/红包"
4. 日期:默认使用今天的日期(%[1]s),除非用户明确指定其他日期
5. 备注:提取关键描述

直接返回JSON,不要解释:
{"amount":数字,"category":"分类","type":"expense或income","note":"备注","date":"YYYY-MM-DD","message":"简短确认"}

示例(假设今天是%[1]s):
用户:"买了6块的奶茶"
返回:{"amount":6,"category":"餐饮","type":"expense","note":"奶茶","date":"%[1]s","message":"记录:餐饮支出6元,奶茶"}`, today)
}

// stripMarkdownCodeFence trims content and, if it starts with a ``` fence,
// removes the opening fence line (``` or ```json) and everything from the
// last ``` onward. Models often wrap JSON replies this way.
func stripMarkdownCodeFence(content string) string {
	content = strings.TrimSpace(content)
	if !strings.HasPrefix(content, "```") {
		return content
	}
	if idx := strings.Index(content, "\n"); idx != -1 {
		content = content[idx+1:]
	}
	if idx := strings.LastIndex(content, "```"); idx != -1 {
		content = content[:idx]
	}
	return strings.TrimSpace(content)
}
|
||
|
||
// parseIntentSimple provides simple regex-based parsing as fallback
|
||
// This is also used as a fast path for simple inputs
|
||
func (s *LLMService) parseIntentSimple(text string) (*AITransactionParams, string, error) {
|
||
params := &AITransactionParams{
|
||
Type: "expense", // Default to expense
|
||
Date: time.Now().Format("2006-01-02"),
|
||
}
|
||
|
||
// Extract amount using regex - support various formats
|
||
amountPatterns := []string{
|
||
`(\d+(?:\.\d+)?)\s*(?:元|块|¥|¥|块钱|元钱)`,
|
||
`(?:花了?|付了?|买了?|消费)\s*(\d+(?:\.\d+)?)`,
|
||
`(\d+(?:\.\d+)?)\s*(?:的|块的)`,
|
||
}
|
||
|
||
for _, pattern := range amountPatterns {
|
||
amountRegex := regexp.MustCompile(pattern)
|
||
if matches := amountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
||
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||
params.Amount = &amount
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
// If still no amount, try simple number extraction
|
||
if params.Amount == nil {
|
||
simpleAmountRegex := regexp.MustCompile(`(\d+(?:\.\d+)?)`)
|
||
if matches := simpleAmountRegex.FindStringSubmatch(text); len(matches) > 1 {
|
||
if amount, err := strconv.ParseFloat(matches[1], 64); err == nil {
|
||
params.Amount = &amount
|
||
}
|
||
}
|
||
}
|
||
|
||
// Enhanced category detection with priority
|
||
categoryPatterns := []struct {
|
||
keywords []string
|
||
category string
|
||
}{
|
||
{[]string{"奶茶", "咖啡", "茶", "饮料", "柠檬", "果汁"}, "餐饮"},
|
||
{[]string{"吃", "喝", "餐", "外卖", "饭", "面", "粉", "粥", "包子", "早餐", "午餐", "晚餐", "宵夜"}, "餐饮"},
|
||
{[]string{"打车", "滴滴", "出租", "的士", "uber", "曹操"}, "交通"},
|
||
{[]string{"地铁", "公交", "公车", "巴士", "轻轨", "高铁", "火车", "飞机", "机票"}, "交通"},
|
||
{[]string{"加油", "油费", "停车", "过路费"}, "交通"},
|
||
{[]string{"超市", "便利店", "商场", "购物", "淘宝", "京东", "拼多多"}, "购物"},
|
||
{[]string{"买", "购"}, "购物"},
|
||
{[]string{"水电", "电费", "水费", "燃气", "煤气", "物业"}, "生活缴费"},
|
||
{[]string{"房租", "租金", "房贷"}, "住房"},
|
||
{[]string{"电影", "游戏", "KTV", "唱歌", "娱乐", "玩"}, "娱乐"},
|
||
{[]string{"医院", "药", "看病", "挂号", "医疗"}, "医疗"},
|
||
{[]string{"话费", "流量", "充值", "手机费"}, "通讯"},
|
||
{[]string{"工资", "薪水", "薪资", "月薪"}, "工资"},
|
||
{[]string{"奖金", "年终奖", "绩效"}, "奖金"},
|
||
{[]string{"红包", "转账", "收款"}, "其他收入"},
|
||
}
|
||
|
||
for _, cp := range categoryPatterns {
|
||
for _, keyword := range cp.keywords {
|
||
if strings.Contains(text, keyword) {
|
||
params.Category = cp.category
|
||
break
|
||
}
|
||
}
|
||
if params.Category != "" {
|
||
break
|
||
}
|
||
}
|
||
|
||
// Default category if not detected
|
||
if params.Category == "" {
|
||
params.Category = "其他"
|
||
}
|
||
|
||
// Detect income keywords
|
||
incomeKeywords := []string{"工资", "薪", "奖金", "红包", "收入", "进账", "到账", "收到", "收款"}
|
||
for _, keyword := range incomeKeywords {
|
||
if strings.Contains(text, keyword) {
|
||
params.Type = "income"
|
||
break
|
||
}
|
||
}
|
||
|
||
// Extract note - remove amount and common words
|
||
note := text
|
||
if params.Amount != nil {
|
||
note = regexp.MustCompile(`\d+(?:\.\d+)?\s*(?:元|块|¥|¥|块钱|元钱)?`).ReplaceAllString(note, "")
|
||
}
|
||
note = strings.TrimSpace(note)
|
||
// Remove common filler words
|
||
fillerWords := []string{"买了", "花了", "付了", "消费了", "一个", "一条", "一份", "的"}
|
||
for _, word := range fillerWords {
|
||
note = strings.ReplaceAll(note, word, "")
|
||
}
|
||
note = strings.TrimSpace(note)
|
||
if note != "" {
|
||
params.Note = note
|
||
}
|
||
|
||
// Generate response message
|
||
var message string
|
||
if params.Amount == nil {
|
||
message = "请问金额是多少?"
|
||
} else {
|
||
typeLabel := "支出"
|
||
if params.Type == "income" {
|
||
typeLabel = "收入"
|
||
}
|
||
message = fmt.Sprintf("记录:%s %.2f元,分类:%s", typeLabel, *params.Amount, params.Category)
|
||
if params.Note != "" {
|
||
message += ",备注:" + params.Note
|
||
}
|
||
}
|
||
|
||
return params, message, nil
|
||
}
|
||
|
||
// MapAccountName resolves an account name (as produced by intent parsing) to
// one of userID's accounts.
//
// A case-insensitive exact match wins. Otherwise the first account whose name
// contains, or is contained in, name (case-insensitively) is returned.
// Accounts with an empty name are skipped in that pass, because "" is
// contained in every string and would match any query.
//
// It returns the account's ID and stored name. It returns (nil, "", nil) when
// name is empty or nothing matches, and the repository error if loading the
// accounts fails.
func (s *LLMService) MapAccountName(ctx context.Context, name string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	accounts, err := s.accountRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	for i := range accounts {
		if strings.EqualFold(accounts[i].Name, name) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}

	query := strings.ToLower(name) // loop-invariant
	for i := range accounts {
		candidate := strings.ToLower(accounts[i].Name)
		if candidate == "" {
			continue
		}
		if strings.Contains(candidate, query) || strings.Contains(query, candidate) {
			return &accounts[i].ID, accounts[i].Name, nil
		}
	}

	return nil, "", nil
}
|
||
|
||
// MapCategoryName resolves a category name (as produced by intent parsing) to
// one of userID's categories of the matching type.
//
// Filtering by txType:
//   - "expense" or "income": only categories of that type are considered;
//   - "": all categories are considered;
//   - any other value: nothing matches.
//
// A case-insensitive exact match wins. Otherwise the first category whose
// name contains, or is contained in, name (case-insensitively) is returned.
// Categories with an empty name are skipped in that pass, because "" is
// contained in every string.
//
// It returns the category's ID and stored name. It returns (nil, "", nil)
// when name is empty or nothing matches, and the repository error if loading
// the categories fails.
func (s *LLMService) MapCategoryName(ctx context.Context, name string, txType string, userID uint) (*uint, string, error) {
	if name == "" {
		return nil, "", nil
	}

	categories, err := s.categoryRepo.GetAll(userID)
	if err != nil {
		return nil, "", err
	}

	var filtered []models.Category
	for _, cat := range categories {
		if (txType == "expense" && cat.Type == "expense") ||
			(txType == "income" && cat.Type == "income") ||
			txType == "" {
			filtered = append(filtered, cat)
		}
	}

	for i := range filtered {
		if strings.EqualFold(filtered[i].Name, name) {
			return &filtered[i].ID, filtered[i].Name, nil
		}
	}

	query := strings.ToLower(name) // loop-invariant
	for i := range filtered {
		candidate := strings.ToLower(filtered[i].Name)
		if candidate == "" {
			continue
		}
		if strings.Contains(candidate, query) || strings.Contains(query, candidate) {
			return &filtered[i].ID, filtered[i].Name, nil
		}
	}

	return nil, "", nil
}
|
||
|
||
// AIBookkeepingService combines speech transcription (WhisperService) and
// intent parsing (LLMService) with the repositories. It also keeps the
// in-memory chat sessions.
type AIBookkeepingService struct {
	whisperService   *WhisperService
	llmService       *LLMService
	transactionRepo  *repository.TransactionRepository
	accountRepo      *repository.AccountRepository
	categoryRepo     *repository.CategoryRepository
	userSettingsRepo *repository.UserSettingsRepository
	// db is stored by the constructor; no method in this file uses it.
	db *gorm.DB
	// sessions maps session ID to session and is guarded by sessionMutex.
	// It is created empty by NewAIBookkeepingService.
	sessions     map[string]*AISession
	sessionMutex sync.RWMutex
	config       *config.Config // supplies AISessionTimeout (and the sub-services' settings)
}
|
||
|
||
// NewAIBookkeepingService creates a new AIBookkeepingService
|
||
func NewAIBookkeepingService(
|
||
cfg *config.Config,
|
||
transactionRepo *repository.TransactionRepository,
|
||
accountRepo *repository.AccountRepository,
|
||
categoryRepo *repository.CategoryRepository,
|
||
userSettingsRepo *repository.UserSettingsRepository,
|
||
db *gorm.DB,
|
||
) *AIBookkeepingService {
|
||
whisperService := NewWhisperService(cfg)
|
||
llmService := NewLLMService(cfg, accountRepo, categoryRepo)
|
||
|
||
svc := &AIBookkeepingService{
|
||
whisperService: whisperService,
|
||
llmService: llmService,
|
||
transactionRepo: transactionRepo,
|
||
accountRepo: accountRepo,
|
||
categoryRepo: categoryRepo,
|
||
userSettingsRepo: userSettingsRepo,
|
||
db: db,
|
||
sessions: make(map[string]*AISession),
|
||
config: cfg,
|
||
}
|
||
|
||
// Start session cleanup goroutine
|
||
go svc.cleanupExpiredSessions()
|
||
|
||
return svc
|
||
}
|
||
|
||
// generateSessionID generates a unique session ID
|
||
func generateSessionID() string {
|
||
return fmt.Sprintf("ai_%d_%d", time.Now().UnixNano(), time.Now().Unix()%1000)
|
||
}
|
||
|
||
// getOrCreateSession returns the caller's live session for sessionID, or
// creates and registers a new one.
//
// An existing session is reused only if it belongs to userID and has not
// expired; an expired session of that user is deleted. A session ID owned by
// another user is never handed out: the caller gets a fresh session and the
// owner's session is left untouched. New sessions expire after
// config.AISessionTimeout.
func (s *AIBookkeepingService) getOrCreateSession(sessionID string, userID uint) *AISession {
	s.sessionMutex.Lock()
	defer s.sessionMutex.Unlock()

	now := time.Now()
	if sessionID != "" {
		if session, ok := s.sessions[sessionID]; ok && session.UserID == userID {
			if now.Before(session.ExpiresAt) {
				return session
			}
			delete(s.sessions, sessionID)
		}
	}

	// Draw again if the generated ID is already in use, so an existing
	// session is never overwritten.
	newID := generateSessionID()
	for {
		if _, taken := s.sessions[newID]; !taken {
			break
		}
		newID = generateSessionID()
	}

	session := &AISession{
		ID:        newID,
		UserID:    userID,
		Params:    &AITransactionParams{},
		Messages:  []ChatMessage{},
		CreatedAt: now,
		ExpiresAt: now.Add(s.config.AISessionTimeout),
	}
	s.sessions[newID] = session
	return session
}
|
||
|
||
// cleanupExpiredSessions never returns. Every five minutes it deletes each
// session whose ExpiresAt is in the past, holding sessionMutex during the
// sweep. NewAIBookkeepingService runs it in its own goroutine; the ticker is
// never stopped.
func (s *AIBookkeepingService) cleanupExpiredSessions() {
	sweep := func() {
		s.sessionMutex.Lock()
		defer s.sessionMutex.Unlock()

		cutoff := time.Now()
		for key, sess := range s.sessions {
			if cutoff.After(sess.ExpiresAt) {
				delete(s.sessions, key)
			}
		}
	}

	ticker := time.NewTicker(5 * time.Minute)
	for range ticker.C {
		sweep()
	}
}
|
||
|
||
// ProcessChat handles one user chat turn. It:
//  1. parses the message and merges the result into the session;
//  2. resolves category and account IDs;
//  3. asks for missing fields, or returns a confirmation card.
//
// The session comes from getOrCreateSession. Any transaction type other than
// "income" is normalized to "expense", the same rule ConfirmTransaction
// applies when saving. This way the category lookup and the card reflect what
// will actually be recorded. Name lookups are best-effort: a lookup error
// leaves the field unresolved, and the defaults below are tried instead.
//
// The only error is a failed intent parse; in that case the user message is
// not kept in the session history.
//
// Requirements: 7.2-7.4, 7.7-7.10, 12.5, 12.8
func (s *AIBookkeepingService) ProcessChat(ctx context.Context, userID uint, sessionID string, message string) (*AIChatResponse, error) {
	session := s.getOrCreateSession(sessionID, userID)

	// NOTE(review): the session is mutated below without holding
	// sessionMutex, so concurrent requests for the same session race.
	session.Messages = append(session.Messages, ChatMessage{Role: "user", Content: message})
	previous := session.Messages[:len(session.Messages)-1]

	// The model gets the earlier turns as history and the new message separately.
	params, responseMsg, err := s.llmService.ParseIntent(ctx, message, previous)
	if err != nil {
		// Drop the unanswered turn so a retry does not duplicate it in history.
		session.Messages = previous
		return nil, fmt.Errorf("failed to parse intent: %w", err)
	}
	if params != nil {
		s.mergeParams(session.Params, params)
	}

	p := session.Params
	if p.Type != "income" {
		p.Type = "expense"
	}

	// Lookup errors are deliberately ignored: an unresolved name simply
	// falls through to the defaults below.
	if p.Account != "" && p.AccountID == nil {
		if id, name, _ := s.llmService.MapAccountName(ctx, p.Account, userID); id != nil {
			p.AccountID = id
			p.Account = name
		}
	}
	if p.Category != "" && p.CategoryID == nil {
		if id, name, _ := s.llmService.MapCategoryName(ctx, p.Category, p.Type, userID); id != nil {
			p.CategoryID = id
			p.Category = name
		}
	}

	// Unknown category name: fall back to the user's first category of this
	// type, keeping the model's name for display.
	// NOTE(review): the card then shows a name that differs from the
	// category that will actually be saved.
	if p.CategoryID == nil && p.Category != "" {
		if id, _ := s.getDefaultCategory(userID, p.Type); id != nil {
			p.CategoryID = id
		}
	}

	// No account given or resolved: use the configured/first account.
	if p.AccountID == nil {
		if id, name := s.getDefaultAccount(userID, p.Type); id != nil {
			p.AccountID = id
			p.Account = name
		}
	}

	response := &AIChatResponse{
		SessionID: session.ID,
		Message:   responseMsg,
		Intent:    "create_transaction",
		Params:    p,
	}

	if missing := s.getMissingFields(p); len(missing) > 0 {
		response.NeedsFollowUp = true
		response.FollowUpQuestion = s.generateFollowUpQuestion(missing)
		if responseMsg == "" {
			response.Message = response.FollowUpQuestion
		}
	} else {
		// getMissingFields guarantees Amount is non-nil here.
		response.ConfirmationCard = s.GenerateConfirmationCard(session)
		response.Message = fmt.Sprintf("请确认:%s %.2f元,分类:%s,账户:%s",
			s.getTypeLabel(p.Type),
			*p.Amount,
			p.Category,
			p.Account)
	}

	session.Messages = append(session.Messages, ChatMessage{Role: "assistant", Content: response.Message})
	return response, nil
}
|
||
|
||
// mergeParams merges new params into existing params
|
||
func (s *AIBookkeepingService) mergeParams(existing, new *AITransactionParams) {
|
||
if new.Amount != nil {
|
||
existing.Amount = new.Amount
|
||
}
|
||
if new.Category != "" {
|
||
existing.Category = new.Category
|
||
}
|
||
if new.CategoryID != nil {
|
||
existing.CategoryID = new.CategoryID
|
||
}
|
||
if new.Account != "" {
|
||
existing.Account = new.Account
|
||
}
|
||
if new.AccountID != nil {
|
||
existing.AccountID = new.AccountID
|
||
}
|
||
if new.Type != "" {
|
||
existing.Type = new.Type
|
||
}
|
||
if new.Date != "" {
|
||
existing.Date = new.Date
|
||
}
|
||
if new.Note != "" {
|
||
existing.Note = new.Note
|
||
}
|
||
}
|
||
|
||
// getDefaultAccount gets the default account based on transaction type
|
||
// If no default is set, returns the first available account
|
||
func (s *AIBookkeepingService) getDefaultAccount(userID uint, txType string) (*uint, string) {
|
||
// First try to get user's configured default account
|
||
settings, err := s.userSettingsRepo.GetOrCreate(userID)
|
||
if err == nil && settings != nil {
|
||
var accountID *uint
|
||
if txType == "expense" && settings.DefaultExpenseAccountID != nil {
|
||
accountID = settings.DefaultExpenseAccountID
|
||
} else if txType == "income" && settings.DefaultIncomeAccountID != nil {
|
||
accountID = settings.DefaultIncomeAccountID
|
||
}
|
||
|
||
if accountID != nil {
|
||
account, err := s.accountRepo.GetByID(userID, *accountID)
|
||
if err == nil && account != nil {
|
||
return accountID, account.Name
|
||
}
|
||
}
|
||
}
|
||
|
||
// Fallback: get the first available account
|
||
accounts, err := s.accountRepo.GetAll(userID)
|
||
if err != nil || len(accounts) == 0 {
|
||
return nil, ""
|
||
}
|
||
|
||
// Return the first account (usually sorted by sort_order)
|
||
return &accounts[0].ID, accounts[0].Name
|
||
}
|
||
|
||
// getDefaultCategory gets the first category of the given type
|
||
func (s *AIBookkeepingService) getDefaultCategory(userID uint, txType string) (*uint, string) {
|
||
categories, err := s.categoryRepo.GetAll(userID)
|
||
if err != nil || len(categories) == 0 {
|
||
return nil, ""
|
||
}
|
||
|
||
// Find the first category matching the transaction type
|
||
categoryType := "expense"
|
||
if txType == "income" {
|
||
categoryType = "income"
|
||
}
|
||
|
||
for _, cat := range categories {
|
||
if string(cat.Type) == categoryType {
|
||
return &cat.ID, cat.Name
|
||
}
|
||
}
|
||
|
||
// If no matching type found, return the first category
|
||
return &categories[0].ID, categories[0].Name
|
||
}
|
||
|
||
// getMissingFields returns list of missing required fields
|
||
func (s *AIBookkeepingService) getMissingFields(params *AITransactionParams) []string {
|
||
var missing []string
|
||
if params.Amount == nil {
|
||
missing = append(missing, "amount")
|
||
}
|
||
if params.CategoryID == nil && params.Category == "" {
|
||
missing = append(missing, "category")
|
||
}
|
||
if params.AccountID == nil && params.Account == "" {
|
||
missing = append(missing, "account")
|
||
}
|
||
return missing
|
||
}
|
||
|
||
// generateFollowUpQuestion builds the question that asks the user for the
// missing fields reported by getMissingFields ("amount", "category",
// "account"; unknown names are ignored).
//
// With one known field it asks a question phrased for that field. "是多少"
// ("how much") only fits the amount, so category and account get their own
// wording. With several fields it lists their labels. It returns "" when no
// known field is missing.
func (s *AIBookkeepingService) generateFollowUpQuestion(missing []string) string {
	type fieldPrompt struct {
		label    string
		question string
	}
	prompts := map[string]fieldPrompt{
		"amount":   {label: "金额", question: "请问金额是多少?"},
		"category": {label: "分类", question: "请问分类是什么?"},
		"account":  {label: "账户", question: "请问使用哪个账户?"},
	}

	var known []fieldPrompt
	for _, field := range missing {
		if p, ok := prompts[field]; ok {
			known = append(known, p)
		}
	}

	switch len(known) {
	case 0:
		return ""
	case 1:
		return known[0].question
	}

	labels := make([]string, len(known))
	for i, p := range known {
		labels[i] = p.label
	}
	return fmt.Sprintf("请补充以下信息:%s", strings.Join(labels, "、"))
}
|
||
|
||
// getTypeLabel returns Chinese label for transaction type
|
||
func (s *AIBookkeepingService) getTypeLabel(txType string) string {
|
||
if txType == "income" {
|
||
return "收入"
|
||
}
|
||
return "支出"
|
||
}
|
||
|
||
// GenerateConfirmationCard creates a confirmation card from session params
|
||
func (s *AIBookkeepingService) GenerateConfirmationCard(session *AISession) *ConfirmationCard {
|
||
params := session.Params
|
||
|
||
card := &ConfirmationCard{
|
||
SessionID: session.ID,
|
||
Type: params.Type,
|
||
Note: params.Note,
|
||
IsComplete: true,
|
||
}
|
||
|
||
if params.Amount != nil {
|
||
card.Amount = *params.Amount
|
||
}
|
||
if params.CategoryID != nil {
|
||
card.CategoryID = *params.CategoryID
|
||
}
|
||
card.Category = params.Category
|
||
if params.AccountID != nil {
|
||
card.AccountID = *params.AccountID
|
||
}
|
||
card.Account = params.Account
|
||
|
||
// Set date
|
||
if params.Date != "" {
|
||
card.Date = params.Date
|
||
} else {
|
||
card.Date = time.Now().Format("2006-01-02")
|
||
}
|
||
|
||
return card
|
||
}
|
||
|
||
// TranscribeAudio forwards to the service's WhisperService; accepted formats,
// configuration requirements and errors are those of
// WhisperService.TranscribeAudio.
func (s *AIBookkeepingService) TranscribeAudio(ctx context.Context, audioData io.Reader, filename string) (*TranscriptionResult, error) {
	whisper := s.whisperService
	return whisper.TranscribeAudio(ctx, audioData, filename)
}
|
||
|
||
// ConfirmTransaction creates a transaction from confirmed card
|
||
// Requirements: 7.10
|
||
func (s *AIBookkeepingService) ConfirmTransaction(ctx context.Context, sessionID string, userID uint) (*models.Transaction, error) {
|
||
s.sessionMutex.RLock()
|
||
session, ok := s.sessions[sessionID]
|
||
s.sessionMutex.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, errors.New("session not found or expired")
|
||
}
|
||
|
||
params := session.Params
|
||
|
||
// Validate required fields
|
||
if params.Amount == nil || *params.Amount <= 0 {
|
||
return nil, errors.New("invalid amount")
|
||
}
|
||
if params.CategoryID == nil {
|
||
return nil, errors.New("category not specified")
|
||
}
|
||
if params.AccountID == nil {
|
||
return nil, errors.New("account not specified")
|
||
}
|
||
|
||
// Parse date
|
||
var txDate time.Time
|
||
if params.Date != "" {
|
||
var err error
|
||
txDate, err = time.Parse("2006-01-02", params.Date)
|
||
if err != nil {
|
||
txDate = time.Now()
|
||
}
|
||
} else {
|
||
txDate = time.Now()
|
||
}
|
||
|
||
// Determine transaction type
|
||
txType := models.TransactionTypeExpense
|
||
if params.Type == "income" {
|
||
txType = models.TransactionTypeIncome
|
||
}
|
||
|
||
// Create transaction
|
||
tx := &models.Transaction{
|
||
UserID: userID,
|
||
Amount: *params.Amount,
|
||
Type: txType,
|
||
CategoryID: *params.CategoryID,
|
||
AccountID: *params.AccountID,
|
||
TransactionDate: txDate,
|
||
Note: params.Note,
|
||
Currency: "CNY",
|
||
}
|
||
|
||
// Save transaction
|
||
if err := s.transactionRepo.Create(tx); err != nil {
|
||
return nil, fmt.Errorf("failed to create transaction: %w", err)
|
||
}
|
||
|
||
// Update account balance
|
||
account, err := s.accountRepo.GetByID(userID, *params.AccountID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to find account: %w", err)
|
||
}
|
||
|
||
if txType == models.TransactionTypeExpense {
|
||
account.Balance -= *params.Amount
|
||
} else {
|
||
account.Balance += *params.Amount
|
||
}
|
||
|
||
if err := s.accountRepo.Update(account); err != nil {
|
||
return nil, fmt.Errorf("failed to update account balance: %w", err)
|
||
}
|
||
|
||
// Clean up session
|
||
s.sessionMutex.Lock()
|
||
delete(s.sessions, sessionID)
|
||
s.sessionMutex.Unlock()
|
||
|
||
return tx, nil
|
||
}
|
||
|
||
// GetSession looks up a session by ID under the read lock. It reports false
// (with a nil session) when the ID is unknown or the session's ExpiresAt has
// passed. It does not check which user owns the session.
func (s *AIBookkeepingService) GetSession(sessionID string) (*AISession, bool) {
	s.sessionMutex.RLock()
	defer s.sessionMutex.RUnlock()

	if session, found := s.sessions[sessionID]; found && !time.Now().After(session.ExpiresAt) {
		return session, true
	}
	return nil, false
}
|