Skip to content

Commit

Permalink
feat:encrypt secret fields of entry before store to db(#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
vimiix committed Jan 9, 2024
1 parent b0fe228 commit b6a9005
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 7 deletions.
164 changes: 164 additions & 0 deletions internal/encrypt/encrypt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2022 Enmotech Inc. All rights reserved.

package encrypt

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"strings"
"time"

"github.com/pkg/errors"

"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/internal/utils"
)

// Encrypt Generates the ciphertext for the given string.
// If the encryption fails, the original characters will be returned.
// If the passed string is empty, return empty directly.
func Encrypt(text string) string {
if text == "" {
return ""
}

curTime := time.Now().Format("01021504")
salt := md5encode(curTime)
key := salt[:8] + curTime

cipherText, err := aesEncrypt(text, key)
if err != nil {
lg.Debug("failed to encrypt text '%s': %s", utils.MaskString(text), err)
return text
}
return base64.StdEncoding.EncodeToString([]byte(salt[:8] + shiftEncode(curTime) + cipherText))
}

func Decrypt(rawCipher string) string {
if rawCipher == "" {
return ""
}

dec, err := base64.StdEncoding.DecodeString(rawCipher)
if err != nil {
lg.Debug("failed to base64 decode cipher text '%s': %s", rawCipher, err)
return rawCipher
}

key := string(dec[:8]) + shiftDecode(string(dec[8:16]))
text := string(dec[16:])
res, err := aesDecrypt(text, key)
if err != nil {
lg.Debug("failed to decypt cipher '%s': %s", text, err)
return rawCipher
}
return res
}

func md5encode(s string) string {
h := md5.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}

func shiftEncode(s string) string {
rs := make([]string, 0, len(s))
for _, c := range s[:] {
// start with '<'
rs = append(rs, fmt.Sprintf("%c", c+12))
}
return strings.Join(rs, "")
}

func shiftDecode(s string) string {
rs := make([]string, 0, len(s))
for _, c := range s[:] {
rs = append(rs, fmt.Sprintf("%c", c-12))
}
return strings.Join(rs, "")
}

func addBase64Padding(value string) string {
m := len(value) % 4
if m != 0 {
value += strings.Repeat("=", 4-m)
}

return value
}

func removeBase64Padding(value string) string {
return strings.Replace(value, "=", "", -1)
}

func pad(src []byte) []byte {
padding := aes.BlockSize - len(src)%aes.BlockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}

func unpad(src []byte) ([]byte, error) {
length := len(src)
unpadding := int(src[length-1])

if unpadding > length {
return nil, errors.New("unpad error. This could happen when incorrect encryption key is used")
}

return src[:(length - unpadding)], nil
}

func aesEncrypt(text string, key string) (string, error) {
block, err := aes.NewCipher([]byte(key))
if err != nil {
return "", err
}

msg := pad([]byte(text))
ciphertext := make([]byte, aes.BlockSize+len(msg))
iv := ciphertext[:aes.BlockSize]
if _, err = io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}

cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext[aes.BlockSize:], msg)
finalMsg := removeBase64Padding(base64.URLEncoding.EncodeToString(ciphertext))
return finalMsg, nil
}

func aesDecrypt(text string, key string) (string, error) {
block, err := aes.NewCipher([]byte(key))
if err != nil {
return "", err
}

decodedMsg, err := base64.URLEncoding.DecodeString(addBase64Padding(text))
if err != nil {
return "", err
}

if (len(decodedMsg) % aes.BlockSize) != 0 {
return "", errors.New("blocksize must be multiple of decoded message length")
}

iv := decodedMsg[:aes.BlockSize]
msg := decodedMsg[aes.BlockSize:]

cfb := cipher.NewCFBDecrypter(block, iv)
cfb.XORKeyStream(msg, msg)

unpadMsg, err := unpad(msg)
if err != nil {
return "", err
}

return string(unpadMsg), nil
}
42 changes: 42 additions & 0 deletions internal/encrypt/encrypt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package encrypt

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestEncryptDecrypt(t *testing.T) {
tests := []struct {
name string
text string
}{
{"empty", ""},
{"regular", "abc123"},
{"symbol", "!*#$)@>?"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := Decrypt(Encrypt(tt.text))
assert.Equal(t, tt.text, actual)
})
}
}

func TestDecrypt(t *testing.T) {
tests := []struct {
name string
cipher string
expect string
}{
{"empty", "", ""},
{"regular", "NmUxODZmYWM8PTxFPUQ9QENIQUc2eGl4T2pEWnQtQ0I2YkE0RkRxRUI0ei1fLUlNMmZKYi1lTFlnQk0=", "abc123"},
{"plaintext", "abc123", "abc123"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := Decrypt(tt.cipher)
assert.Equal(t, tt.expect, actual)
})
}
}
11 changes: 11 additions & 0 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,14 @@ func ExpandHomeDir(path string) string {

return filepath.Join(u.HomeDir, path[1:])
}

func MaskString(s string) string {
mask := "***"
if len(s) == 0 {
return s
} else if len(s) <= 3 {
return s[:1] + mask
} else {
return s[:2] + mask + s[len(s)-1:]
}
}
18 changes: 18 additions & 0 deletions internal/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,21 @@ func TestExpandHomeDir(t *testing.T) {
})
}
}

func TestMaskString(t *testing.T) {
tests := []struct {
s string
expect string
}{
{"", ""},
{"a", "a***"},
{"ab", "a***"},
{"abc", "a***"},
{"abcd", "ab***d"},
{"abcdefgh", "ab***h"},
}
for _, tt := range tests {
actual := MaskString(tt.s)
assert.Equal(t, tt.expect, actual)
}
}
31 changes: 24 additions & 7 deletions ssx/bbolt/bbolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"go.etcd.io/bbolt"

"github.com/vimiix/ssx/internal/encrypt"
"github.com/vimiix/ssx/internal/errmsg"
"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/ssx/entry"
Expand Down Expand Up @@ -82,9 +83,9 @@ func (r *Repo) TouchEntry(e *entry.Entry) error {
e.UpdateAt = time.Now()
}
// update
buf, marshalErr := json.Marshal(e)
if marshalErr != nil {
return marshalErr
buf, encodeErr := encodeEntry(e)
if encodeErr != nil {
return encodeErr
}
return b.Put(itob(e.ID), buf)
})
Expand Down Expand Up @@ -125,11 +126,11 @@ func (r *Repo) GetAllEntries() (map[uint64]*entry.Entry, error) {
b := tx.Bucket(r.entryBucket)
c := b.Cursor()
for k, v := c.First(); k != nil; k, v = c.Next() {
var t = entry.Entry{}
if unmarshalErr := json.Unmarshal(v, &t); unmarshalErr != nil {
return unmarshalErr
e, decodeErr := decodeEntry(v)
if decodeErr != nil {
return decodeErr
}
m[t.ID] = &t
m[e.ID] = e
}
return nil
})
Expand Down Expand Up @@ -197,3 +198,19 @@ func NewRepo(file string) *Repo {
entryBucket: []byte("entries"),
}
}

func encodeEntry(e *entry.Entry) ([]byte, error) {
e.Password = encrypt.Encrypt(e.Password)
e.Passphrase = encrypt.Encrypt(e.Passphrase)
return json.Marshal(e)
}

func decodeEntry(bs []byte) (*entry.Entry, error) {
var e = &entry.Entry{}
if err := json.Unmarshal(bs, e); err != nil {
return nil, err
}
e.Password = encrypt.Decrypt(e.Password)
e.Passphrase = encrypt.Decrypt(e.Passphrase)
return e, nil
}

0 comments on commit b6a9005

Please sign in to comment.