Skip to content

Commit

Permalink
Merge pull request #2 from sd416/1.1
Browse files Browse the repository at this point in the history
Major speed inprovements in encyption and decryption
  • Loading branch information
sd416 authored Jul 29, 2024
2 parents 7d7a46b + 4ba5769 commit 35d0866
Showing 1 changed file with 149 additions and 54 deletions.
203 changes: 149 additions & 54 deletions pkg/fileops/fileops.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fileops

import (
"bufio"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
Expand All @@ -10,80 +11,123 @@ import (
"fmt"
"io"
"os"
"runtime"
"sync"
"time"
)

const chunkSize = 64 * 1024 // 64KB chunks

func EncryptFile(inputFile, outputFile string, encryptor crypto.Encryptor, logger *logging.Logger) error {
logger.Log(fmt.Sprintf("Starting encryption of file: %s", inputFile))
startTime := time.Now()

// Read the input file
plaintext, err := os.ReadFile(inputFile)
inFile, err := os.Open(inputFile)
if err != nil {
return fmt.Errorf("error reading input file: %v", err)
return fmt.Errorf("error opening input file: %v", err)
}
defer inFile.Close()

outFile, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("error creating output file: %v", err)
}
logger.Log(fmt.Sprintf("Read %d bytes from input file", len(plaintext)))
defer outFile.Close()

bufReader := bufio.NewReaderSize(inFile, chunkSize)
bufWriter := bufio.NewWriterSize(outFile, chunkSize)
defer bufWriter.Flush()

// Generate a random AES key
aesKey := make([]byte, 32) // AES-256
if _, err := rand.Read(aesKey); err != nil {
if _, err := io.ReadFull(rand.Reader, aesKey); err != nil {
return fmt.Errorf("error generating AES key: %v", err)
}
logger.Log("Generated AES key")

// Encrypt the AES key
encryptedAESKey, err := encryptor.EncryptKey(aesKey)
if err != nil {
return fmt.Errorf("error encrypting AES key: %v", err)
}
logger.Log(fmt.Sprintf("Encrypted AES key (length: %d bytes)", len(encryptedAESKey)))

// Create and open the output file
outFile, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("error creating output file: %v", err)
}
defer outFile.Close()

// Write the length of the encrypted AES key
if err := binary.Write(outFile, binary.BigEndian, uint32(len(encryptedAESKey))); err != nil {
if err := binary.Write(bufWriter, binary.BigEndian, uint32(len(encryptedAESKey))); err != nil {
return fmt.Errorf("error writing key length: %v", err)
}

// Write the encrypted AES key
if _, err := outFile.Write(encryptedAESKey); err != nil {
if _, err := bufWriter.Write(encryptedAESKey); err != nil {
return fmt.Errorf("error writing encrypted key: %v", err)
}
logger.Log("Wrote encrypted AES key to file")

// Create AES cipher
block, err := aes.NewCipher(aesKey)
if err != nil {
return fmt.Errorf("error creating AES cipher: %v", err)
}

// Create a random IV
iv := make([]byte, aes.BlockSize)
if _, err := rand.Read(iv); err != nil {
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return fmt.Errorf("error generating IV: %v", err)
}

// Write the IV
if _, err := outFile.Write(iv); err != nil {
if _, err := bufWriter.Write(iv); err != nil {
return fmt.Errorf("error writing IV: %v", err)
}
logger.Log("Wrote IV to file")

// Create the AES CTR stream
stream := cipher.NewCTR(block, iv)
numWorkers := runtime.NumCPU()
jobs := make(chan []byte, numWorkers)
results := make(chan []byte, numWorkers)
errors := make(chan error, 1)

// Encrypt and write the data
encryptedData := make([]byte, len(plaintext))
stream.XORKeyStream(encryptedData, plaintext)
if _, err := outFile.Write(encryptedData); err != nil {
return fmt.Errorf("error writing encrypted data: %v", err)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(workerIV []byte) {
defer wg.Done()
stream := cipher.NewCTR(block, workerIV)
for chunk := range jobs {
encryptedChunk := make([]byte, len(chunk))
stream.XORKeyStream(encryptedChunk, chunk)
results <- encryptedChunk
}
}(incrementIV(iv))
}

go func() {
defer close(jobs)
for {
chunk := make([]byte, chunkSize)
n, err := bufReader.Read(chunk)
if err != nil {
if err != io.EOF {
errors <- fmt.Errorf("error reading input file: %v", err)
}
break
}
if n == 0 {
break
}
jobs <- chunk[:n]
}
}()

go func() {
defer close(results)
for encryptedChunk := range results {
if _, err := bufWriter.Write(encryptedChunk); err != nil {
errors <- fmt.Errorf("error writing encrypted data: %v", err)
return
}
}
}()

wg.Wait()

select {
case err := <-errors:
return err
default:
}
logger.Log(fmt.Sprintf("Wrote %d bytes of encrypted data", len(encryptedData)))

duration := time.Since(startTime)
logger.Log(fmt.Sprintf("Encryption completed. Duration: %v", duration))
Expand All @@ -94,28 +138,34 @@ func DecryptFile(inputFile, outputFile string, decryptor crypto.Decryptor, logge
logger.Log(fmt.Sprintf("Starting decryption of file: %s", inputFile))
startTime := time.Now()

// Open the input file
inFile, err := os.Open(inputFile)
if err != nil {
return fmt.Errorf("error opening input file: %v", err)
}
defer inFile.Close()

// Read the length of the encrypted AES key
outFile, err := os.Create(outputFile)
if err != nil {
return fmt.Errorf("error creating output file: %v", err)
}
defer outFile.Close()

bufReader := bufio.NewReaderSize(inFile, chunkSize)
bufWriter := bufio.NewWriterSize(outFile, chunkSize)
defer bufWriter.Flush()

var keyLength uint32
if err := binary.Read(inFile, binary.BigEndian, &keyLength); err != nil {
if err := binary.Read(bufReader, binary.BigEndian, &keyLength); err != nil {
return fmt.Errorf("error reading key length: %v", err)
}
logger.Log(fmt.Sprintf("Read encrypted AES key length: %d bytes", keyLength))

// Read the encrypted AES key
encryptedAESKey := make([]byte, keyLength)
if _, err := io.ReadFull(inFile, encryptedAESKey); err != nil {
if _, err := io.ReadFull(bufReader, encryptedAESKey); err != nil {
return fmt.Errorf("error reading encrypted key: %v", err)
}
logger.Log(fmt.Sprintf("Read encrypted AES key (%d bytes)", len(encryptedAESKey)))

// Decrypt the AES key
aesKey, err := decryptor.DecryptKey(encryptedAESKey)
if err != nil {
return fmt.Errorf("error decrypting AES key: %v", err)
Expand All @@ -125,40 +175,85 @@ func DecryptFile(inputFile, outputFile string, decryptor crypto.Decryptor, logge
}
logger.Log("Successfully decrypted AES key")

// Create AES cipher
block, err := aes.NewCipher(aesKey)
if err != nil {
return fmt.Errorf("error creating AES cipher: %v", err)
}

// Read the IV
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(inFile, iv); err != nil {
if _, err := io.ReadFull(bufReader, iv); err != nil {
return fmt.Errorf("error reading IV: %v", err)
}
logger.Log("Read IV from file")

// Create the AES CTR stream
stream := cipher.NewCTR(block, iv)
numWorkers := runtime.NumCPU()
jobs := make(chan []byte, numWorkers)
results := make(chan []byte, numWorkers)
errors := make(chan error, 1)

// Read the encrypted data
encryptedData, err := io.ReadAll(inFile)
if err != nil {
return fmt.Errorf("error reading encrypted data: %v", err)
var wg sync.WaitGroup
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(workerIV []byte) {
defer wg.Done()
stream := cipher.NewCTR(block, workerIV)
for chunk := range jobs {
decryptedChunk := make([]byte, len(chunk))
stream.XORKeyStream(decryptedChunk, chunk)
results <- decryptedChunk
}
}(incrementIV(iv))
}
logger.Log(fmt.Sprintf("Read %d bytes of encrypted data", len(encryptedData)))

// Decrypt the data
decryptedData := make([]byte, len(encryptedData))
stream.XORKeyStream(decryptedData, encryptedData)
go func() {
defer close(jobs)
for {
chunk := make([]byte, chunkSize)
n, err := bufReader.Read(chunk)
if err != nil {
if err != io.EOF {
errors <- fmt.Errorf("error reading encrypted data: %v", err)
}
break
}
if n == 0 {
break
}
jobs <- chunk[:n]
}
}()

go func() {
defer close(results)
for decryptedChunk := range results {
if _, err := bufWriter.Write(decryptedChunk); err != nil {
errors <- fmt.Errorf("error writing decrypted data: %v", err)
return
}
}
}()

// Write the decrypted data to the output file
if err := os.WriteFile(outputFile, decryptedData, 0644); err != nil {
return fmt.Errorf("error writing decrypted data: %v", err)
wg.Wait()

select {
case err := <-errors:
return err
default:
}
logger.Log(fmt.Sprintf("Wrote %d bytes of decrypted data", len(decryptedData)))

duration := time.Since(startTime)
logger.Log(fmt.Sprintf("Decryption completed. Duration: %v", duration))
return nil
}

func incrementIV(iv []byte) []byte {
newIV := make([]byte, len(iv))
copy(newIV, iv)
for i := len(newIV) - 1; i >= 0; i-- {
newIV[i]++
if newIV[i] != 0 {
break
}
}
return newIV
}

0 comments on commit 35d0866

Please sign in to comment.