diff --git a/internal/encrypt/encrypt.go b/internal/encrypt/encrypt.go new file mode 100644 index 0000000..9c16a85 --- /dev/null +++ b/internal/encrypt/encrypt.go @@ -0,0 +1,164 @@ +// Copyright 2022 Enmotech Inc. All rights reserved. + +package encrypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + "strings" + "time" + + "github.com/pkg/errors" + + "github.com/vimiix/ssx/internal/lg" + "github.com/vimiix/ssx/internal/utils" +) + +// Encrypt Generates the ciphertext for the given string. +// If the encryption fails, the original characters will be returned. +// If the passed string is empty, return empty directly. +func Encrypt(text string) string { + if text == "" { + return "" + } + + curTime := time.Now().Format("01021504") + salt := md5encode(curTime) + key := salt[:8] + curTime + + cipherText, err := aesEncrypt(text, key) + if err != nil { + lg.Debug("failed to encrypt text '%s': %s", utils.MaskString(text), err) + return text + } + return base64.StdEncoding.EncodeToString([]byte(salt[:8] + shiftEncode(curTime) + cipherText)) +} + +func Decrypt(rawCipher string) string { + if rawCipher == "" { + return "" + } + + dec, err := base64.StdEncoding.DecodeString(rawCipher) + if err != nil { + lg.Debug("failed to base64 decode cipher text '%s': %s", rawCipher, err) + return rawCipher + } + + key := string(dec[:8]) + shiftDecode(string(dec[8:16])) + text := string(dec[16:]) + res, err := aesDecrypt(text, key) + if err != nil { + lg.Debug("failed to decypt cipher '%s': %s", text, err) + return rawCipher + } + return res +} + +func md5encode(s string) string { + h := md5.New() + h.Write([]byte(s)) + return hex.EncodeToString(h.Sum(nil)) +} + +func shiftEncode(s string) string { + rs := make([]string, 0, len(s)) + for _, c := range s[:] { + // start with '<' + rs = append(rs, fmt.Sprintf("%c", c+12)) + } + return strings.Join(rs, "") +} + +func shiftDecode(s string) string { + rs := make([]string, 0, len(s)) + for _, c := range s[:] { + rs = append(rs, fmt.Sprintf("%c", c-12)) + } + return strings.Join(rs, "") +} + +func addBase64Padding(value string) string { + m := len(value) % 4 + if m != 0 { + value += strings.Repeat("=", 4-m) + } + + return value +} + +func removeBase64Padding(value string) string { + return strings.Replace(value, "=", "", -1) +} + +func pad(src []byte) []byte { + padding := aes.BlockSize - len(src)%aes.BlockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func unpad(src []byte) ([]byte, error) { + length := len(src) + unpadding := int(src[length-1]) + + if unpadding > length { + return nil, errors.New("unpad error. This could happen when incorrect encryption key is used") + } + + return src[:(length - unpadding)], nil +} + +func aesEncrypt(text string, key string) (string, error) { + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + msg := pad([]byte(text)) + ciphertext := make([]byte, aes.BlockSize+len(msg)) + iv := ciphertext[:aes.BlockSize] + if _, err = io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + + cfb := cipher.NewCFBEncrypter(block, iv) + cfb.XORKeyStream(ciphertext[aes.BlockSize:], msg) + finalMsg := removeBase64Padding(base64.URLEncoding.EncodeToString(ciphertext)) + return finalMsg, nil +} + +func aesDecrypt(text string, key string) (string, error) { + block, err := aes.NewCipher([]byte(key)) + if err != nil { + return "", err + } + + decodedMsg, err := base64.URLEncoding.DecodeString(addBase64Padding(text)) + if err != nil { + return "", err + } + + if (len(decodedMsg) % aes.BlockSize) != 0 { + return "", errors.New("blocksize must be multiple of decoded message length") + } + + iv := decodedMsg[:aes.BlockSize] + msg := decodedMsg[aes.BlockSize:] + + cfb := cipher.NewCFBDecrypter(block, iv) + cfb.XORKeyStream(msg, msg) + + unpadMsg, err := unpad(msg) + if err != nil { + return "", err + } + + return string(unpadMsg), nil +} diff --git a/internal/encrypt/encrypt_test.go b/internal/encrypt/encrypt_test.go new file mode 100644 index 0000000..280451e --- /dev/null +++ b/internal/encrypt/encrypt_test.go @@ -0,0 +1,42 @@ +package encrypt + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncryptDecrypt(t *testing.T) { + tests := []struct { + name string + text string + }{ + {"empty", ""}, + {"regular", "abc123"}, + {"symbol", "!*#$)@>?"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := Decrypt(Encrypt(tt.text)) + assert.Equal(t, tt.text, actual) + }) + } +} + +func TestDecrypt(t *testing.T) { + tests := []struct { + name string + cipher string + expect string + }{ + {"empty", "", ""}, + {"regular", "NmUxODZmYWM8PTxFPUQ9QENIQUc2eGl4T2pEWnQtQ0I2YkE0RkRxRUI0ei1fLUlNMmZKYi1lTFlnQk0=", "abc123"}, + {"plaintext", "abc123", "abc123"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := Decrypt(tt.cipher) + assert.Equal(t, tt.expect, actual) + }) + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 6f6a5d2..5e0c459 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -30,3 +30,14 @@ func ExpandHomeDir(path string) string { return filepath.Join(u.HomeDir, path[1:]) } + +func MaskString(s string) string { + mask := "***" + if len(s) == 0 { + return s + } else if len(s) <= 3 { + return s[:1] + mask + } else { + return s[:2] + mask + s[len(s)-1:] + } +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 71664c1..b90a93a 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -50,3 +50,21 @@ func TestExpandHomeDir(t *testing.T) { }) } } + +func TestMaskString(t *testing.T) { + tests := []struct { + s string + expect string + }{ + {"", ""}, + {"a", "a***"}, + {"ab", "a***"}, + {"abc", "a***"}, + {"abcd", "ab***d"}, + {"abcdefgh", "ab***h"}, + } + for _, tt := range tests { + actual := MaskString(tt.s) + assert.Equal(t, tt.expect, actual) + } +} diff --git a/ssx/bbolt/bbolt.go b/ssx/bbolt/bbolt.go index ca45a2f..4ced470 100644 --- a/ssx/bbolt/bbolt.go +++ b/ssx/bbolt/bbolt.go @@ -7,6 +7,7 @@ import ( "go.etcd.io/bbolt" + "github.com/vimiix/ssx/internal/encrypt" "github.com/vimiix/ssx/internal/errmsg" "github.com/vimiix/ssx/internal/lg" "github.com/vimiix/ssx/ssx/entry" @@ -82,9 +83,9 @@ func (r *Repo) TouchEntry(e *entry.Entry) error { e.UpdateAt = time.Now() } // update - buf, marshalErr := json.Marshal(e) - if marshalErr != nil { - return marshalErr + buf, encodeErr := encodeEntry(e) + if encodeErr != nil { + return encodeErr } return b.Put(itob(e.ID), buf) }) @@ -125,11 +126,11 @@ func (r *Repo) GetAllEntries() (map[uint64]*entry.Entry, error) { b := tx.Bucket(r.entryBucket) c := b.Cursor() for k, v := c.First(); k != nil; k, v = c.Next() { - var t = entry.Entry{} - if unmarshalErr := json.Unmarshal(v, &t); unmarshalErr != nil { - return unmarshalErr + e, decodeErr := decodeEntry(v) + if decodeErr != nil { + return decodeErr } - m[t.ID] = &t + m[e.ID] = e } return nil }) @@ -197,3 +198,19 @@ func NewRepo(file string) *Repo { entryBucket: []byte("entries"), } } + +func encodeEntry(e *entry.Entry) ([]byte, error) { + e.Password = encrypt.Encrypt(e.Password) + e.Passphrase = encrypt.Encrypt(e.Passphrase) + return json.Marshal(e) +} + +func decodeEntry(bs []byte) (*entry.Entry, error) { + var e = &entry.Entry{} + if err := json.Unmarshal(bs, e); err != nil { + return nil, err + } + e.Password = encrypt.Decrypt(e.Password) + e.Passphrase = encrypt.Decrypt(e.Passphrase) + return e, nil +}