344 lines
9.4 KiB
Go
344 lines
9.4 KiB
Go
package service
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Sentinel errors returned by BackupService. Some call sites wrap them with
// extra context via %w, so callers should compare with errors.Is.
var (
	// ErrInvalidPassword is returned when a backup password is shorter than
	// 8 bytes.
	ErrInvalidPassword = errors.New("invalid password: must be at least 8 characters")

	// ErrBackupFailed signals a generic failure of a backup operation.
	ErrBackupFailed = errors.New("backup operation failed")

	// ErrInvalidBackupFile is returned when a backup file is too short to
	// contain the checksum header and nonce.
	ErrInvalidBackupFile = errors.New("invalid backup file")

	// ErrChecksumMismatch is returned when the stored checksum does not match
	// the payload, or does not match the checksum supplied by the caller.
	ErrChecksumMismatch = errors.New("checksum verification failed")

	// ErrDecryptionFailed is returned when AES-GCM authentication fails during
	// decryption (wrong password or tampered data).
	ErrDecryptionFailed = errors.New("decryption failed")
)
|
|
|
|
// BackupService handles data backup and restore operations
|
|
type BackupService struct {
|
|
db *gorm.DB
|
|
dbPath string // Optional: explicit database path
|
|
}
|
|
|
|
// NewBackupService creates a new BackupService instance
|
|
func NewBackupService(db *gorm.DB) *BackupService {
|
|
return &BackupService{
|
|
db: db,
|
|
dbPath: "", // Will be auto-detected
|
|
}
|
|
}
|
|
|
|
// NewBackupServiceWithPath creates a new BackupService instance with explicit database path
|
|
func NewBackupServiceWithPath(db *gorm.DB, dbPath string) *BackupService {
|
|
return &BackupService{
|
|
db: db,
|
|
dbPath: dbPath,
|
|
}
|
|
}
|
|
|
|
// BackupRequest is the input accepted by ExportBackup.
type BackupRequest struct {
	// Password encrypts the backup; ExportBackup rejects values shorter than 8 bytes.
	Password string `json:"password" binding:"required,min=8"`
	// FilePath optionally overrides where the backup file is written; when
	// empty, a timestamped name in the working directory is used.
	FilePath string `json:"file_path,omitempty"`
}
|
|
|
|
// BackupResponse describes the backup file produced by ExportBackup.
type BackupResponse struct {
	// FilePath is where the backup was written.
	FilePath string `json:"file_path"`
	// Checksum is the hex-encoded SHA-256 stored at the start of the file.
	Checksum string `json:"checksum"`
	// Size is the backup file size in bytes.
	Size int64 `json:"size"`
	// Created is the creation time in RFC 3339 format.
	Created string `json:"created"`
}
|
|
|
|
// RestoreRequest is the input accepted by ImportBackup.
type RestoreRequest struct {
	// Password is the passphrase the backup was encrypted with.
	Password string `json:"password" binding:"required"`
	// FilePath is the backup file to restore from.
	FilePath string `json:"file_path" binding:"required"`
	// Checksum, when non-empty, must equal the checksum stored in the backup file.
	Checksum string `json:"checksum,omitempty"`
}
|
|
|
|
// ExportBackup creates an encrypted backup of the database
|
|
// The backup file format:
|
|
// - First 32 bytes: SHA256 checksum of encrypted data
|
|
// - Next 12 bytes: AES-GCM nonce
|
|
// - Remaining bytes: Encrypted database content
|
|
func (s *BackupService) ExportBackup(req BackupRequest) (*BackupResponse, error) {
|
|
// Validate password
|
|
if len(req.Password) < 8 {
|
|
return nil, ErrInvalidPassword
|
|
}
|
|
|
|
// Get database file path
|
|
dbPath, err := s.getDatabasePath()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get database path: %w", err)
|
|
}
|
|
|
|
// Read database file
|
|
dbData, err := os.ReadFile(dbPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read database file: %w", err)
|
|
}
|
|
|
|
// Encrypt the data
|
|
encryptedData, err := s.encryptData(dbData, req.Password)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encrypt data: %w", err)
|
|
}
|
|
|
|
// Calculate checksum of encrypted data
|
|
checksum := s.calculateChecksum(encryptedData)
|
|
|
|
// Prepare backup file with checksum prefix
|
|
backupData := append([]byte(checksum), encryptedData...)
|
|
|
|
// Determine output file path
|
|
outputPath := req.FilePath
|
|
if outputPath == "" {
|
|
timestamp := time.Now().Format("20060102_150405")
|
|
outputPath = filepath.Join(".", fmt.Sprintf("backup_%s.enc", timestamp))
|
|
}
|
|
|
|
// Write backup file
|
|
if err := os.WriteFile(outputPath, backupData, 0600); err != nil {
|
|
return nil, fmt.Errorf("failed to write backup file: %w", err)
|
|
}
|
|
|
|
// Get file info
|
|
fileInfo, err := os.Stat(outputPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get file info: %w", err)
|
|
}
|
|
|
|
return &BackupResponse{
|
|
FilePath: outputPath,
|
|
Checksum: checksum,
|
|
Size: fileInfo.Size(),
|
|
Created: time.Now().Format(time.RFC3339),
|
|
}, nil
|
|
}
|
|
|
|
// ImportBackup restores the database from an encrypted backup file
|
|
func (s *BackupService) ImportBackup(req RestoreRequest) error {
|
|
// Read backup file
|
|
backupData, err := os.ReadFile(req.FilePath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read backup file: %w", err)
|
|
}
|
|
|
|
// Backup file must have at least checksum (64 bytes hex = 32 bytes) + nonce (12 bytes) + some data
|
|
if len(backupData) < 76 {
|
|
return ErrInvalidBackupFile
|
|
}
|
|
|
|
// Extract checksum (first 64 bytes as hex string)
|
|
storedChecksum := string(backupData[:64])
|
|
encryptedData := backupData[64:]
|
|
|
|
// Verify checksum
|
|
calculatedChecksum := s.calculateChecksum(encryptedData)
|
|
if storedChecksum != calculatedChecksum {
|
|
return ErrChecksumMismatch
|
|
}
|
|
|
|
// If checksum provided in request, verify it matches
|
|
if req.Checksum != "" && req.Checksum != storedChecksum {
|
|
return ErrChecksumMismatch
|
|
}
|
|
|
|
// Decrypt the data
|
|
decryptedData, err := s.decryptData(encryptedData, req.Password)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt data: %w", err)
|
|
}
|
|
|
|
// Get database file path
|
|
dbPath, err := s.getDatabasePath()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database path: %w", err)
|
|
}
|
|
|
|
// Close existing database connections
|
|
sqlDB, err := s.db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database connection: %w", err)
|
|
}
|
|
if err := sqlDB.Close(); err != nil {
|
|
return fmt.Errorf("failed to close database connection: %w", err)
|
|
}
|
|
|
|
// Create backup of current database before overwriting
|
|
backupPath := dbPath + ".backup"
|
|
if err := s.copyFile(dbPath, backupPath); err != nil {
|
|
// If backup fails, log but continue (original file still exists)
|
|
fmt.Printf("Warning: failed to create backup of current database: %v\n", err)
|
|
}
|
|
|
|
// Write decrypted data to database file
|
|
if err := os.WriteFile(dbPath, decryptedData, 0600); err != nil {
|
|
// Try to restore from backup if write fails
|
|
if backupErr := s.copyFile(backupPath, dbPath); backupErr != nil {
|
|
return fmt.Errorf("failed to write database file and restore backup: %w, %w", err, backupErr)
|
|
}
|
|
return fmt.Errorf("failed to write database file: %w", err)
|
|
}
|
|
|
|
// Remove temporary backup
|
|
os.Remove(backupPath)
|
|
|
|
return nil
|
|
}
|
|
|
|
// encryptData encrypts data using AES-256-GCM with the provided password
|
|
func (s *BackupService) encryptData(data []byte, password string) ([]byte, error) {
|
|
// Derive a 32-byte key from password using SHA256
|
|
key := sha256.Sum256([]byte(password))
|
|
|
|
// Create AES cipher
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
// Create GCM mode
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
// Generate nonce
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
// Encrypt data (nonce is prepended to ciphertext)
|
|
ciphertext := gcm.Seal(nonce, nonce, data, nil)
|
|
|
|
return ciphertext, nil
|
|
}
|
|
|
|
// decryptData decrypts data using AES-256-GCM with the provided password
|
|
func (s *BackupService) decryptData(encryptedData []byte, password string) ([]byte, error) {
|
|
// Derive a 32-byte key from password using SHA256
|
|
key := sha256.Sum256([]byte(password))
|
|
|
|
// Create AES cipher
|
|
block, err := aes.NewCipher(key[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
|
}
|
|
|
|
// Create GCM mode
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
// Check minimum size
|
|
nonceSize := gcm.NonceSize()
|
|
if len(encryptedData) < nonceSize {
|
|
return nil, ErrInvalidBackupFile
|
|
}
|
|
|
|
// Extract nonce and ciphertext
|
|
nonce := encryptedData[:nonceSize]
|
|
ciphertext := encryptedData[nonceSize:]
|
|
|
|
// Decrypt data
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return nil, ErrDecryptionFailed
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
// calculateChecksum calculates SHA256 checksum of data and returns hex string
|
|
func (s *BackupService) calculateChecksum(data []byte) string {
|
|
hash := sha256.Sum256(data)
|
|
return hex.EncodeToString(hash[:])
|
|
}
|
|
|
|
// getDatabasePath returns the path to the SQLite database file
|
|
func (s *BackupService) getDatabasePath() (string, error) {
|
|
// If explicit path is set, use it
|
|
if s.dbPath != "" {
|
|
if _, err := os.Stat(s.dbPath); err != nil {
|
|
return "", fmt.Errorf("database file not found at %s: %w", s.dbPath, err)
|
|
}
|
|
return s.dbPath, nil
|
|
}
|
|
|
|
// Get the database connection
|
|
sqlDB, err := s.db.DB()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Ensure we have a valid connection
|
|
if err := sqlDB.Ping(); err != nil {
|
|
return "", fmt.Errorf("database connection not available: %w", err)
|
|
}
|
|
|
|
// For SQLite, try common database paths
|
|
possiblePaths := []string{
|
|
"data/accounting.db",
|
|
"./accounting.db",
|
|
"../data/accounting.db",
|
|
}
|
|
|
|
for _, dbPath := range possiblePaths {
|
|
if _, err := os.Stat(dbPath); err == nil {
|
|
return dbPath, nil
|
|
}
|
|
}
|
|
|
|
// If no standard path found, return error
|
|
return "", fmt.Errorf("database file not found in standard locations")
|
|
}
|
|
|
|
// copyFile copies a file from src to dst
|
|
func (s *BackupService) copyFile(src, dst string) error {
|
|
sourceData, err := os.ReadFile(src)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(dst, sourceData, 0600)
|
|
}
|
|
|
|
// VerifyBackup verifies the integrity of a backup file without restoring it
|
|
func (s *BackupService) VerifyBackup(filePath string) (bool, string, error) {
|
|
// Read backup file
|
|
backupData, err := os.ReadFile(filePath)
|
|
if err != nil {
|
|
return false, "", fmt.Errorf("failed to read backup file: %w", err)
|
|
}
|
|
|
|
// Backup file must have at least checksum + nonce + some data
|
|
if len(backupData) < 76 {
|
|
return false, "", ErrInvalidBackupFile
|
|
}
|
|
|
|
// Extract checksum
|
|
storedChecksum := string(backupData[:64])
|
|
encryptedData := backupData[64:]
|
|
|
|
// Verify checksum
|
|
calculatedChecksum := s.calculateChecksum(encryptedData)
|
|
isValid := storedChecksum == calculatedChecksum
|
|
|
|
return isValid, storedChecksum, nil
|
|
}
|