package service import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "os" "path/filepath" "time" "gorm.io/gorm" ) // Service layer errors for backup var ( ErrInvalidPassword = errors.New("invalid password: must be at least 8 characters") ErrBackupFailed = errors.New("backup operation failed") ErrInvalidBackupFile = errors.New("invalid backup file") ErrChecksumMismatch = errors.New("checksum verification failed") ErrDecryptionFailed = errors.New("decryption failed") ) // BackupService handles data backup and restore operations type BackupService struct { db *gorm.DB dbPath string // Optional: explicit database path } // NewBackupService creates a new BackupService instance func NewBackupService(db *gorm.DB) *BackupService { return &BackupService{ db: db, dbPath: "", // Will be auto-detected } } // NewBackupServiceWithPath creates a new BackupService instance with explicit database path func NewBackupServiceWithPath(db *gorm.DB, dbPath string) *BackupService { return &BackupService{ db: db, dbPath: dbPath, } } // BackupRequest represents the input for creating a backup type BackupRequest struct { Password string `json:"password" binding:"required,min=8"` FilePath string `json:"file_path,omitempty"` // Optional custom path } // BackupResponse represents the response after creating a backup type BackupResponse struct { FilePath string `json:"file_path"` Checksum string `json:"checksum"` Size int64 `json:"size"` Created string `json:"created"` } // RestoreRequest represents the input for restoring from a backup type RestoreRequest struct { Password string `json:"password" binding:"required"` FilePath string `json:"file_path" binding:"required"` Checksum string `json:"checksum,omitempty"` // Optional checksum verification } // ExportBackup creates an encrypted backup of the database // The backup file format: // - First 32 bytes: SHA256 checksum of encrypted data // - Next 12 bytes: AES-GCM nonce // - Remaining bytes: Encrypted database content func (s *BackupService) ExportBackup(req BackupRequest) (*BackupResponse, error) { // Validate password if len(req.Password) < 8 { return nil, ErrInvalidPassword } // Get database file path dbPath, err := s.getDatabasePath() if err != nil { return nil, fmt.Errorf("failed to get database path: %w", err) } // Read database file dbData, err := os.ReadFile(dbPath) if err != nil { return nil, fmt.Errorf("failed to read database file: %w", err) } // Encrypt the data encryptedData, err := s.encryptData(dbData, req.Password) if err != nil { return nil, fmt.Errorf("failed to encrypt data: %w", err) } // Calculate checksum of encrypted data checksum := s.calculateChecksum(encryptedData) // Prepare backup file with checksum prefix backupData := append([]byte(checksum), encryptedData...) // Determine output file path outputPath := req.FilePath if outputPath == "" { timestamp := time.Now().Format("20060102_150405") outputPath = filepath.Join(".", fmt.Sprintf("backup_%s.enc", timestamp)) } // Write backup file if err := os.WriteFile(outputPath, backupData, 0600); err != nil { return nil, fmt.Errorf("failed to write backup file: %w", err) } // Get file info fileInfo, err := os.Stat(outputPath) if err != nil { return nil, fmt.Errorf("failed to get file info: %w", err) } return &BackupResponse{ FilePath: outputPath, Checksum: checksum, Size: fileInfo.Size(), Created: time.Now().Format(time.RFC3339), }, nil } // ImportBackup restores the database from an encrypted backup file func (s *BackupService) ImportBackup(req RestoreRequest) error { // Read backup file backupData, err := os.ReadFile(req.FilePath) if err != nil { return fmt.Errorf("failed to read backup file: %w", err) } // Backup file must have at least checksum (64 bytes hex = 32 bytes) + nonce (12 bytes) + some data if len(backupData) < 76 { return ErrInvalidBackupFile } // Extract checksum (first 64 bytes as hex string) storedChecksum := string(backupData[:64]) encryptedData := backupData[64:] // Verify checksum calculatedChecksum := s.calculateChecksum(encryptedData) if storedChecksum != calculatedChecksum { return ErrChecksumMismatch } // If checksum provided in request, verify it matches if req.Checksum != "" && req.Checksum != storedChecksum { return ErrChecksumMismatch } // Decrypt the data decryptedData, err := s.decryptData(encryptedData, req.Password) if err != nil { return fmt.Errorf("failed to decrypt data: %w", err) } // Get database file path dbPath, err := s.getDatabasePath() if err != nil { return fmt.Errorf("failed to get database path: %w", err) } // Close existing database connections sqlDB, err := s.db.DB() if err != nil { return fmt.Errorf("failed to get database connection: %w", err) } if err := sqlDB.Close(); err != nil { return fmt.Errorf("failed to close database connection: %w", err) } // Create backup of current database before overwriting backupPath := dbPath + ".backup" if err := s.copyFile(dbPath, backupPath); err != nil { // If backup fails, log but continue (original file still exists) fmt.Printf("Warning: failed to create backup of current database: %v\n", err) } // Write decrypted data to database file if err := os.WriteFile(dbPath, decryptedData, 0600); err != nil { // Try to restore from backup if write fails if backupErr := s.copyFile(backupPath, dbPath); backupErr != nil { return fmt.Errorf("failed to write database file and restore backup: %w, %w", err, backupErr) } return fmt.Errorf("failed to write database file: %w", err) } // Remove temporary backup os.Remove(backupPath) return nil } // encryptData encrypts data using AES-256-GCM with the provided password func (s *BackupService) encryptData(data []byte, password string) ([]byte, error) { // Derive a 32-byte key from password using SHA256 key := sha256.Sum256([]byte(password)) // Create AES cipher block, err := aes.NewCipher(key[:]) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } // Create GCM mode gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM: %w", err) } // Generate nonce nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, fmt.Errorf("failed to generate nonce: %w", err) } // Encrypt data (nonce is prepended to ciphertext) ciphertext := gcm.Seal(nonce, nonce, data, nil) return ciphertext, nil } // decryptData decrypts data using AES-256-GCM with the provided password func (s *BackupService) decryptData(encryptedData []byte, password string) ([]byte, error) { // Derive a 32-byte key from password using SHA256 key := sha256.Sum256([]byte(password)) // Create AES cipher block, err := aes.NewCipher(key[:]) if err != nil { return nil, fmt.Errorf("failed to create cipher: %w", err) } // Create GCM mode gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("failed to create GCM: %w", err) } // Check minimum size nonceSize := gcm.NonceSize() if len(encryptedData) < nonceSize { return nil, ErrInvalidBackupFile } // Extract nonce and ciphertext nonce := encryptedData[:nonceSize] ciphertext := encryptedData[nonceSize:] // Decrypt data plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, ErrDecryptionFailed } return plaintext, nil } // calculateChecksum calculates SHA256 checksum of data and returns hex string func (s *BackupService) calculateChecksum(data []byte) string { hash := sha256.Sum256(data) return hex.EncodeToString(hash[:]) } // getDatabasePath returns the path to the SQLite database file func (s *BackupService) getDatabasePath() (string, error) { // If explicit path is set, use it if s.dbPath != "" { if _, err := os.Stat(s.dbPath); err != nil { return "", fmt.Errorf("database file not found at %s: %w", s.dbPath, err) } return s.dbPath, nil } // Get the database connection sqlDB, err := s.db.DB() if err != nil { return "", err } // Ensure we have a valid connection if err := sqlDB.Ping(); err != nil { return "", fmt.Errorf("database connection not available: %w", err) } // For SQLite, try common database paths possiblePaths := []string{ "data/accounting.db", "./accounting.db", "../data/accounting.db", } for _, dbPath := range possiblePaths { if _, err := os.Stat(dbPath); err == nil { return dbPath, nil } } // If no standard path found, return error return "", fmt.Errorf("database file not found in standard locations") } // copyFile copies a file from src to dst func (s *BackupService) copyFile(src, dst string) error { sourceData, err := os.ReadFile(src) if err != nil { return err } return os.WriteFile(dst, sourceData, 0600) } // VerifyBackup verifies the integrity of a backup file without restoring it func (s *BackupService) VerifyBackup(filePath string) (bool, string, error) { // Read backup file backupData, err := os.ReadFile(filePath) if err != nil { return false, "", fmt.Errorf("failed to read backup file: %w", err) } // Backup file must have at least checksum + nonce + some data if len(backupData) < 76 { return false, "", ErrInvalidBackupFile } // Extract checksum storedChecksum := string(backupData[:64]) encryptedData := backupData[64:] // Verify checksum calculatedChecksum := s.calculateChecksum(encryptedData) isValid := storedChecksum == calculatedChecksum return isValid, storedChecksum, nil }