Files
Novault-backend/internal/service/backup_service.go
2026-01-25 21:59:00 +08:00

344 lines
9.4 KiB
Go

package service
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"time"
"gorm.io/gorm"
)
// Sentinel errors returned by BackupService. Compare against them with
// errors.Is, since some are returned wrapped.
var (
	// ErrInvalidPassword is returned by ExportBackup when the password is
	// shorter than the 8-byte minimum.
	ErrInvalidPassword = errors.New("invalid password: must be at least 8 characters")

	// ErrBackupFailed is a generic backup failure sentinel. It is not returned
	// by the functions in this file; presumably used by callers.
	ErrBackupFailed = errors.New("backup operation failed")

	// ErrInvalidBackupFile is returned when a backup file is too short to hold
	// the checksum header and the encryption nonce.
	ErrInvalidBackupFile = errors.New("invalid backup file")

	// ErrChecksumMismatch is returned when the stored checksum does not match
	// the encrypted payload or a caller-supplied checksum.
	ErrChecksumMismatch = errors.New("checksum verification failed")

	// ErrDecryptionFailed is returned when AES-GCM authentication fails, e.g.
	// a wrong password or tampered data.
	ErrDecryptionFailed = errors.New("decryption failed")
)
// BackupService exports the application's SQLite database to an encrypted
// backup file and restores the database from such a file.
type BackupService struct {
	// db is the live GORM handle; ImportBackup closes its underlying pool.
	db *gorm.DB
	// dbPath, when non-empty, is the database file to operate on; when empty,
	// getDatabasePath auto-detects the location on every call.
	dbPath string
}
// NewBackupService returns a BackupService that auto-detects the database
// file location each time a backup or restore runs.
func NewBackupService(db *gorm.DB) *BackupService {
	svc := &BackupService{db: db} // dbPath left empty => auto-detect
	return svc
}
// NewBackupServiceWithPath returns a BackupService bound to an explicit
// database file. The path's existence is checked lazily, when a backup or
// restore runs, not here; an empty dbPath behaves like NewBackupService.
func NewBackupServiceWithPath(db *gorm.DB, dbPath string) *BackupService {
	svc := new(BackupService)
	svc.db = db
	svc.dbPath = dbPath
	return svc
}
// BackupRequest carries the parameters for ExportBackup.
type BackupRequest struct {
	// Password is used to derive the encryption key (minimum 8 bytes).
	Password string `json:"password" binding:"required,min=8"`
	// FilePath optionally overrides the output location; when empty,
	// ExportBackup writes backup_YYYYMMDD_HHMMSS.enc to the working directory.
	FilePath string `json:"file_path,omitempty"`
}
// BackupResponse describes the backup file produced by ExportBackup.
type BackupResponse struct {
	FilePath string `json:"file_path"` // where the backup was written
	Checksum string `json:"checksum"`  // hex SHA-256 of the encrypted payload
	Size     int64  `json:"size"`      // backup file size in bytes
	Created  string `json:"created"`   // creation time, RFC 3339
}
// RestoreRequest carries the parameters for ImportBackup.
type RestoreRequest struct {
	// Password must match the one used when the backup was created.
	Password string `json:"password" binding:"required"`
	// FilePath is the backup file to restore from.
	FilePath string `json:"file_path" binding:"required"`
	// Checksum, if non-empty, must exactly equal the checksum stored in the
	// backup file (the value ExportBackup reported).
	Checksum string `json:"checksum,omitempty"`
}
// ExportBackup writes an encrypted copy of the SQLite database file.
//
// Backup file layout:
//   - bytes [0, 64): lowercase hex SHA-256 checksum of everything after it
//   - next 12 bytes: AES-GCM nonce
//   - remainder: AES-256-GCM ciphertext of the database file, followed by
//     the 16-byte GCM authentication tag
//
// When req.FilePath is empty the backup is written to
// ./backup_YYYYMMDD_HHMMSS.enc. Returns ErrInvalidPassword if the password
// is shorter than 8 bytes.
//
// NOTE(review): the database file is read byte-for-byte while the connection
// pool is still open. If the database runs in WAL mode, committed
// transactions not yet checkpointed into the main file are missing from the
// backup. Consider SQLite's "VACUUM INTO" for a consistent snapshot.
//
// NOTE(review): req.FilePath is used verbatim. If it comes from an HTTP
// request body, a client can write files anywhere the process can. Confirm
// that the caller restricts it.
func (s *BackupService) ExportBackup(req BackupRequest) (*BackupResponse, error) {
	if len(req.Password) < 8 {
		return nil, ErrInvalidPassword
	}

	dbPath, err := s.getDatabasePath()
	if err != nil {
		return nil, fmt.Errorf("failed to get database path: %w", err)
	}

	dbData, err := os.ReadFile(dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read database file: %w", err)
	}

	encryptedData, err := s.encryptData(dbData, req.Password)
	if err != nil {
		return nil, fmt.Errorf("failed to encrypt data: %w", err)
	}

	// The checksum covers nonce+ciphertext, so VerifyBackup can detect
	// corruption without knowing the password.
	checksum := s.calculateChecksum(encryptedData)

	// Assemble header + payload in a single allocation.
	backupData := make([]byte, 0, len(checksum)+len(encryptedData))
	backupData = append(backupData, checksum...)
	backupData = append(backupData, encryptedData...)

	// One timestamp, so the default filename and the reported creation time
	// always agree.
	now := time.Now()

	outputPath := req.FilePath
	if outputPath == "" {
		outputPath = filepath.Join(".", fmt.Sprintf("backup_%s.enc", now.Format("20060102_150405")))
	}

	if err := os.WriteFile(outputPath, backupData, 0600); err != nil {
		return nil, fmt.Errorf("failed to write backup file: %w", err)
	}

	fileInfo, err := os.Stat(outputPath)
	if err != nil {
		return nil, fmt.Errorf("failed to get file info: %w", err)
	}

	return &BackupResponse{
		FilePath: outputPath,
		Checksum: checksum,
		Size:     fileInfo.Size(),
		Created:  now.Format(time.RFC3339),
	}, nil
}
// ImportBackup replaces the SQLite database file with the contents of an
// encrypted backup produced by ExportBackup.
//
// Steps:
//  1. Verify the embedded checksum.
//  2. If req.Checksum is set, compare it with the stored checksum.
//  3. Decrypt the payload with req.Password.
//  4. Write the plaintext to a temporary file beside the database and
//     fsync it.
//  5. Close the connection pool and move the file into place with os.Rename.
//
// A failure at any step therefore leaves the original database file intact;
// it is never truncated or half-written.
//
// Malformed or corrupted files yield ErrInvalidBackupFile or
// ErrChecksumMismatch. Decryption errors are returned wrapped, so a wrong
// password satisfies errors.Is(err, ErrDecryptionFailed).
//
// NOTE(review): this closes the *gorm.DB connection pool shared with the rest
// of the application and does not reopen it. The caller presumably has to
// reinitialise the database handle, or restart, after a restore.
func (s *BackupService) ImportBackup(req RestoreRequest) error {
	backupData, err := os.ReadFile(req.FilePath)
	if err != nil {
		return fmt.Errorf("failed to read backup file: %w", err)
	}

	// 64 hex checksum bytes + 12-byte nonce is the minimum header; the GCM
	// tag length is enforced later by decryptData.
	if len(backupData) < 76 {
		return ErrInvalidBackupFile
	}

	storedChecksum := string(backupData[:64])
	encryptedData := backupData[64:]

	if s.calculateChecksum(encryptedData) != storedChecksum {
		return ErrChecksumMismatch
	}
	if req.Checksum != "" && req.Checksum != storedChecksum {
		return ErrChecksumMismatch
	}

	decryptedData, err := s.decryptData(encryptedData, req.Password)
	if err != nil {
		return fmt.Errorf("failed to decrypt data: %w", err)
	}

	dbPath, err := s.getDatabasePath()
	if err != nil {
		return fmt.Errorf("failed to get database path: %w", err)
	}

	// Stage the restored contents before touching the live connection, so a
	// disk-full or permission error leaves the running database untouched.
	// The temp file lives in the same directory so os.Rename stays on one
	// filesystem.
	tmp, err := os.CreateTemp(filepath.Dir(dbPath), filepath.Base(dbPath)+".restore-*")
	if err != nil {
		return fmt.Errorf("failed to create temporary database file: %w", err)
	}
	tmpPath := tmp.Name()
	// Clean up the staged file on every early return. After a successful
	// rename the path no longer exists, so the error is intentionally ignored.
	defer func() { _ = os.Remove(tmpPath) }()

	if _, err := tmp.Write(decryptedData); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to write database file: %w", err)
	}
	// Flush to stable storage before the rename, so a crash cannot leave a
	// renamed-but-empty database.
	if err := tmp.Sync(); err != nil {
		tmp.Close()
		return fmt.Errorf("failed to sync database file: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("failed to write database file: %w", err)
	}

	// CreateTemp uses mode 0600; keep the permissions of the file being replaced.
	if info, err := os.Stat(dbPath); err == nil {
		if err := os.Chmod(tmpPath, info.Mode().Perm()); err != nil {
			return fmt.Errorf("failed to set database file permissions: %w", err)
		}
	}

	// Close the pool so no connection keeps using the old file after it is
	// replaced.
	sqlDB, err := s.db.DB()
	if err != nil {
		return fmt.Errorf("failed to get database connection: %w", err)
	}
	if err := sqlDB.Close(); err != nil {
		return fmt.Errorf("failed to close database connection: %w", err)
	}

	// NOTE(review): leftover "-wal"/"-shm"/"-journal" files next to dbPath are
	// not removed here. SQLite normally deletes them when the last connection
	// closes, but confirm this for the journal mode in use.
	if err := os.Rename(tmpPath, dbPath); err != nil {
		return fmt.Errorf("failed to write database file: %w", err)
	}
	return nil
}
// encryptData seals data with AES-256-GCM under the key SHA-256(password).
// The result is the random 12-byte nonce, then the ciphertext, then the
// 16-byte authentication tag. decryptData expects exactly this layout.
//
// NOTE(review): a single unsalted SHA-256 is a weak password-based KDF and
// makes offline guessing cheap. Moving to a salted, slow KDF (scrypt,
// Argon2id, PBKDF2) needs a new, versioned backup format, because it must
// change together with decryptData.
func (s *BackupService) encryptData(data []byte, password string) ([]byte, error) {
	key := sha256.Sum256([]byte(password))

	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("failed to generate nonce: %w", err)
	}

	// Reserve room for nonce || ciphertext || tag up front, then let Seal
	// append the ciphertext right after the nonce.
	out := make([]byte, len(nonce), len(nonce)+len(data)+aead.Overhead())
	copy(out, nonce)
	return aead.Seal(out, nonce, data, nil), nil
}
// decryptData reverses encryptData. It splits off the leading GCM nonce, then
// authenticates and decrypts the remainder with the key SHA-256(password).
//
// It returns ErrInvalidBackupFile if encryptedData is shorter than a nonce.
// It returns ErrDecryptionFailed if authentication fails: wrong password,
// corrupted data, or data too short to hold the tag.
func (s *BackupService) decryptData(encryptedData []byte, password string) ([]byte, error) {
	key := sha256.Sum256([]byte(password))

	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	n := aead.NonceSize()
	if n > len(encryptedData) {
		return nil, ErrInvalidBackupFile
	}

	plaintext, openErr := aead.Open(nil, encryptedData[:n], encryptedData[n:], nil)
	if openErr != nil {
		// The underlying error is dropped on purpose: GCM cannot tell a wrong
		// key from tampered data, so one sentinel covers both.
		return nil, ErrDecryptionFailed
	}
	return plaintext, nil
}
// calculateChecksum returns the SHA-256 digest of data as a lowercase hex
// string, which is always 64 characters long.
func (s *BackupService) calculateChecksum(data []byte) string {
	hasher := sha256.New()
	hasher.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// getDatabasePath returns the path of the SQLite database file to back up or
// restore.
//
// Resolution order:
//  1. The explicit path given to NewBackupServiceWithPath, which must exist.
//  2. The file SQLite reports for the "main" schema via PRAGMA
//     database_list. This matches the connection actually in use,
//     regardless of the process working directory.
//  3. As a fallback (e.g. a non-SQLite driver or an in-memory database), the
//     first existing file among data/accounting.db, ./accounting.db and
//     ../data/accounting.db, relative to the working directory.
func (s *BackupService) getDatabasePath() (string, error) {
	if s.dbPath != "" {
		if _, err := os.Stat(s.dbPath); err != nil {
			return "", fmt.Errorf("database file not found at %s: %w", s.dbPath, err)
		}
		return s.dbPath, nil
	}

	sqlDB, err := s.db.DB()
	if err != nil {
		return "", err
	}
	if err := sqlDB.Ping(); err != nil {
		return "", fmt.Errorf("database connection not available: %w", err)
	}

	// Ask SQLite which file backs this connection, so the wrong
	// accounting.db is never picked. Errors here are not fatal; they only
	// mean we fall back to the well-known locations below.
	if rows, err := sqlDB.Query("PRAGMA database_list"); err == nil {
		var mainFile string
		for rows.Next() {
			var seq any
			var name string
			var file []byte // NULL-safe; SQLite reports "" for non-file databases
			if rows.Scan(&seq, &name, &file) == nil && name == "main" {
				mainFile = string(file)
				break
			}
		}
		rows.Close()
		if mainFile != "" {
			if _, err := os.Stat(mainFile); err == nil {
				return mainFile, nil
			}
		}
	}

	possiblePaths := []string{
		"data/accounting.db",
		"./accounting.db",
		"../data/accounting.db",
	}
	for _, candidate := range possiblePaths {
		if _, err := os.Stat(candidate); err == nil {
			return candidate, nil
		}
	}

	return "", fmt.Errorf("database file not found in standard locations")
}
// copyFile copies src to dst, creating or truncating dst. Mode 0600 applies
// only when dst is newly created.
//
// Data is streamed with io.Copy rather than read fully into memory, which
// matters for large database files. dst is fsynced before returning, so a
// reported success survives a crash. If src cannot be opened, dst is not
// created.
func (s *BackupService) copyFile(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	// Surface a Close failure (e.g. a delayed write error) unless an earlier
	// error is already being returned.
	defer func() {
		if cerr := out.Close(); err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(out, in); err != nil {
		return err
	}
	return out.Sync()
}
// VerifyBackup checks a backup file's integrity without decrypting it.
//
// It recomputes the SHA-256 checksum of the payload (everything after the
// 64-byte hex header) and compares it with the stored header. It returns
// whether they match, plus the stored checksum. A non-nil error means the
// file could not be read or is too short to be a backup
// (ErrInvalidBackupFile). A checksum mismatch is reported through the bool,
// not as an error.
func (s *BackupService) VerifyBackup(filePath string) (bool, string, error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return false, "", fmt.Errorf("failed to read backup file: %w", err)
	}

	const headerLen = 64          // hex-encoded SHA-256
	const minLen = headerLen + 12 // header + AES-GCM nonce
	if len(raw) < minLen {
		return false, "", ErrInvalidBackupFile
	}

	header, payload := string(raw[:headerLen]), raw[headerLen:]
	return s.calculateChecksum(payload) == header, header, nil
}