1 Star 1 Fork 1

李.斯科特 / gm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
sm2.go 14.10 KB
一键复制 编辑 原始数据 按行查看 历史
李.斯科特 提交于 2023-03-24 15:11 . 初始化
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
package sm2
import (
"bytes"
"crypto/elliptic"
"crypto/rand"
"encoding/asn1"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"math/big"
"gitee.com/LeeScott/gm/sm3"
"gitee.com/LeeScott/gm/util"
)
const (
BitSize = 256
KeyBytes = (BitSize + 7) / 8
UnCompress = 0x04
)
type Sm2CipherTextType int32
const (
// 旧标准的密文顺序
C1C2C3 Sm2CipherTextType = 1
// [GM/T 0009-2012]标准规定的顺序
C1C3C2 Sm2CipherTextType = 2
)
var (
sm2H = new(big.Int).SetInt64(1)
sm2SignDefaultUserId = []byte{
0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38,
0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
)
var sm2P256V1 P256V1Curve
type P256V1Curve struct {
*elliptic.CurveParams
A *big.Int
}
type PublicKey struct {
X, Y *big.Int
Curve P256V1Curve
}
type PrivateKey struct {
D *big.Int
Curve P256V1Curve
}
type sm2Signature struct {
R, S *big.Int
}
type sm2CipherC1C3C2 struct {
X, Y *big.Int
C3 []byte
C2 []byte
}
type sm2CipherC1C2C3 struct {
X, Y *big.Int
C2 []byte
C3 []byte
}
func init() {
initSm2P256V1()
}
func initSm2P256V1() {
sm2P, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
sm2A, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", 16)
sm2B, _ := new(big.Int).SetString("28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16)
sm2N, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16)
sm2Gx, _ := new(big.Int).SetString("32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16)
sm2Gy, _ := new(big.Int).SetString("BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16)
sm2P256V1.CurveParams = &elliptic.CurveParams{Name: "SM2-P-256-V1"}
sm2P256V1.P = sm2P
sm2P256V1.A = sm2A
sm2P256V1.B = sm2B
sm2P256V1.N = sm2N
sm2P256V1.Gx = sm2Gx
sm2P256V1.Gy = sm2Gy
sm2P256V1.BitSize = BitSize
}
func GetSm2P256V1() P256V1Curve {
return sm2P256V1
}
func GenerateKey(rand io.Reader) (*PrivateKey, *PublicKey, error) {
priv, x, y, err := elliptic.GenerateKey(sm2P256V1, rand)
if err != nil {
return nil, nil, err
}
privateKey := new(PrivateKey)
privateKey.Curve = sm2P256V1
privateKey.D = new(big.Int).SetBytes(priv)
publicKey := new(PublicKey)
publicKey.Curve = sm2P256V1
publicKey.X = x
publicKey.Y = y
return privateKey, publicKey, nil
}
func RawBytesToPublicKey(bytes []byte) (*PublicKey, error) {
if len(bytes) != KeyBytes*2 {
return nil, errors.New(fmt.Sprintf("Public key raw bytes length must be %d", KeyBytes*2))
}
publicKey := new(PublicKey)
publicKey.Curve = sm2P256V1
publicKey.X = new(big.Int).SetBytes(bytes[:KeyBytes])
publicKey.Y = new(big.Int).SetBytes(bytes[KeyBytes:])
return publicKey, nil
}
func RawBytesToPrivateKey(bytes []byte) (*PrivateKey, error) {
if len(bytes) != KeyBytes {
return nil, errors.New(fmt.Sprintf("Private key raw bytes length must be %d", KeyBytes))
}
privateKey := new(PrivateKey)
privateKey.Curve = sm2P256V1
privateKey.D = new(big.Int).SetBytes(bytes)
return privateKey, nil
}
func (pub *PublicKey) GetUnCompressBytes() []byte {
xBytes := bigIntTo32Bytes(pub.X)
yBytes := bigIntTo32Bytes(pub.Y)
xl := len(xBytes)
yl := len(yBytes)
raw := make([]byte, 1+KeyBytes*2)
raw[0] = UnCompress
if xl > KeyBytes {
copy(raw[1:1+KeyBytes], xBytes[xl-KeyBytes:])
} else if xl < KeyBytes {
copy(raw[1+(KeyBytes-xl):1+KeyBytes], xBytes)
} else {
copy(raw[1:1+KeyBytes], xBytes)
}
if yl > KeyBytes {
copy(raw[1+KeyBytes:], yBytes[yl-KeyBytes:])
} else if yl < KeyBytes {
copy(raw[1+KeyBytes+(KeyBytes-yl):], yBytes)
} else {
copy(raw[1+KeyBytes:], yBytes)
}
return raw
}
func (pub *PublicKey) GetRawBytes() []byte {
raw := pub.GetUnCompressBytes()
return raw[1:]
}
func (pri *PrivateKey) GetRawBytes() []byte {
dBytes := bigIntTo32Bytes(pri.D)
dl := len(dBytes)
if dl > KeyBytes {
raw := make([]byte, KeyBytes)
copy(raw, dBytes[dl-KeyBytes:])
return raw
} else if dl < KeyBytes {
raw := make([]byte, KeyBytes)
copy(raw[KeyBytes-dl:], dBytes)
return raw
} else {
return dBytes
}
}
func CalculatePubKey(priv *PrivateKey) *PublicKey {
pub := new(PublicKey)
pub.Curve = priv.Curve
pub.X, pub.Y = priv.Curve.ScalarBaseMult(priv.D.Bytes())
return pub
}
func nextK(rnd io.Reader, max *big.Int) (*big.Int, error) {
intOne := new(big.Int).SetInt64(1)
var k *big.Int
var err error
for {
k, err = rand.Int(rnd, max)
if err != nil {
return nil, err
}
if k.Cmp(intOne) >= 0 {
return k, err
}
}
}
func xor(data []byte, kdfOut []byte, dRemaining int) {
for i := 0; i != dRemaining; i++ {
data[i] ^= kdfOut[i]
}
}
// 表示SM2 Key的大数比较小时,直接通过Bytes()函数得到的字节数组可能不够32字节,这个时候要补齐成32字节
func bigIntTo32Bytes(bn *big.Int) []byte {
byteArr := bn.Bytes()
byteArrLen := len(byteArr)
if byteArrLen == KeyBytes {
return byteArr
}
byteArr = append(make([]byte, KeyBytes-byteArrLen), byteArr...)
return byteArr
}
func kdf(digest hash.Hash, c1x *big.Int, c1y *big.Int, encData []byte) {
bufSize := 4
if bufSize < digest.Size() {
bufSize = digest.Size()
}
buf := make([]byte, bufSize)
encDataLen := len(encData)
c1xBytes := bigIntTo32Bytes(c1x)
c1yBytes := bigIntTo32Bytes(c1y)
off := 0
ct := uint32(0)
for off < encDataLen {
digest.Reset()
digest.Write(c1xBytes)
digest.Write(c1yBytes)
ct++
binary.BigEndian.PutUint32(buf, ct)
digest.Write(buf[:4])
tmp := digest.Sum(nil)
copy(buf[:bufSize], tmp[:bufSize])
xorLen := encDataLen - off
if xorLen > digest.Size() {
xorLen = digest.Size()
}
xor(encData[off:], buf, xorLen)
off += xorLen
}
}
func notEncrypted(encData []byte, in []byte) bool {
encDataLen := len(encData)
for i := 0; i != encDataLen; i++ {
if encData[i] != in[i] {
return false
}
}
return true
}
func Encrypt(pub *PublicKey, in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) {
c2 := make([]byte, len(in))
copy(c2, in)
var c1 []byte
digest := sm3.New()
var kPBx, kPBy *big.Int
for {
k, err := nextK(rand.Reader, pub.Curve.N)
if err != nil {
return nil, err
}
kBytes := k.Bytes()
c1x, c1y := pub.Curve.ScalarBaseMult(kBytes)
c1 = elliptic.Marshal(pub.Curve, c1x, c1y)
kPBx, kPBy = pub.Curve.ScalarMult(pub.X, pub.Y, kBytes)
kdf(digest, kPBx, kPBy, c2)
if !notEncrypted(c2, in) {
break
}
}
digest.Reset()
digest.Write(bigIntTo32Bytes(kPBx))
digest.Write(in)
digest.Write(bigIntTo32Bytes(kPBy))
c3 := digest.Sum(nil)
c1Len := len(c1)
c2Len := len(c2)
c3Len := len(c3)
result := make([]byte, c1Len+c2Len+c3Len)
if cipherTextType == C1C2C3 {
copy(result[:c1Len], c1)
copy(result[c1Len:c1Len+c2Len], c2)
copy(result[c1Len+c2Len:], c3)
} else if cipherTextType == C1C3C2 {
copy(result[:c1Len], c1)
copy(result[c1Len:c1Len+c3Len], c3)
copy(result[c1Len+c3Len:], c2)
} else {
return nil, errors.New("unknown cipherTextType:" + string(cipherTextType))
}
return result, nil
}
func Decrypt(priv *PrivateKey, in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) {
c1Len := ((priv.Curve.BitSize+7)>>3)*2 + 1
c1 := make([]byte, c1Len)
copy(c1, in[:c1Len])
c1x, c1y := elliptic.Unmarshal(priv.Curve, c1)
sx, sy := priv.Curve.ScalarMult(c1x, c1y, sm2H.Bytes())
if util.IsEcPointInfinity(sx, sy) {
return nil, errors.New("[h]C1 at infinity")
}
c1x, c1y = priv.Curve.ScalarMult(c1x, c1y, priv.D.Bytes())
digest := sm3.New()
c3Len := digest.Size()
c2Len := len(in) - c1Len - c3Len
c2 := make([]byte, c2Len)
c3 := make([]byte, c3Len)
if cipherTextType == C1C2C3 {
copy(c2, in[c1Len:c1Len+c2Len])
copy(c3, in[c1Len+c2Len:])
} else if cipherTextType == C1C3C2 {
copy(c3, in[c1Len:c1Len+c3Len])
copy(c2, in[c1Len+c3Len:])
} else {
return nil, errors.New("unknown cipherTextType:" + string(cipherTextType))
}
kdf(digest, c1x, c1y, c2)
digest.Reset()
digest.Write(bigIntTo32Bytes(c1x))
digest.Write(c2)
digest.Write(bigIntTo32Bytes(c1y))
newC3 := digest.Sum(nil)
if !bytes.Equal(newC3, c3) {
return nil, errors.New("invalid cipher text")
}
return c2, nil
}
func MarshalCipher(in []byte, cipherTextType Sm2CipherTextType) ([]byte, error) {
byteLen := (sm2P256V1.Params().BitSize + 7) >> 3
c1x := make([]byte, byteLen)
c1y := make([]byte, byteLen)
c2Len := len(in) - (1 + byteLen*2) - sm3.DigestLength
c2 := make([]byte, c2Len)
c3 := make([]byte, sm3.DigestLength)
pos := 1
copy(c1x, in[pos:pos+byteLen])
pos += byteLen
copy(c1y, in[pos:pos+byteLen])
pos += byteLen
nc1x := new(big.Int).SetBytes(c1x)
nc1y := new(big.Int).SetBytes(c1y)
if cipherTextType == C1C2C3 {
copy(c2, in[pos:pos+c2Len])
pos += c2Len
copy(c3, in[pos:pos+sm3.DigestLength])
result, err := asn1.Marshal(sm2CipherC1C2C3{nc1x, nc1y, c2, c3})
if err != nil {
return nil, err
}
return result, nil
} else if cipherTextType == C1C3C2 {
copy(c3, in[pos:pos+sm3.DigestLength])
pos += sm3.DigestLength
copy(c2, in[pos:pos+c2Len])
result, err := asn1.Marshal(sm2CipherC1C3C2{nc1x, nc1y, c3, c2})
if err != nil {
return nil, err
}
return result, nil
} else {
return nil, errors.New("unknown cipherTextType:" + string(cipherTextType))
}
}
func UnmarshalCipher(in []byte, cipherTextType Sm2CipherTextType) (out []byte, err error) {
if cipherTextType == C1C2C3 {
cipher := new(sm2CipherC1C2C3)
_, err = asn1.Unmarshal(in, cipher)
if err != nil {
return nil, err
}
c1xBytes := bigIntTo32Bytes(cipher.X)
c1yBytes := bigIntTo32Bytes(cipher.Y)
c1xLen := len(c1xBytes)
c1yLen := len(c1yBytes)
c2Len := len(cipher.C2)
c3Len := len(cipher.C3)
result := make([]byte, 1+c1xLen+c1yLen+c2Len+c3Len)
pos := 0
result[pos] = UnCompress
pos += 1
copy(result[pos:pos+c1xLen], c1xBytes)
pos += c1xLen
copy(result[pos:pos+c1yLen], c1yBytes)
pos += c1yLen
copy(result[pos:pos+c2Len], cipher.C2)
pos += c2Len
copy(result[pos:pos+c3Len], cipher.C3)
return result, nil
} else if cipherTextType == C1C3C2 {
cipher := new(sm2CipherC1C3C2)
_, err = asn1.Unmarshal(in, cipher)
if err != nil {
return nil, err
}
c1xBytes := bigIntTo32Bytes(cipher.X)
c1yBytes := bigIntTo32Bytes(cipher.Y)
c1xLen := len(c1xBytes)
c1yLen := len(c1yBytes)
c2Len := len(cipher.C2)
c3Len := len(cipher.C3)
result := make([]byte, 1+c1xLen+c1yLen+c2Len+c3Len)
pos := 0
result[pos] = UnCompress
pos += 1
copy(result[pos:pos+c1xLen], c1xBytes)
pos += c1xLen
copy(result[pos:pos+c1yLen], c1yBytes)
pos += c1yLen
copy(result[pos:pos+c3Len], cipher.C3)
pos += c3Len
copy(result[pos:pos+c2Len], cipher.C2)
return result, nil
} else {
return nil, errors.New("unknown cipherTextType:" + string(cipherTextType))
}
}
func getZ(digest hash.Hash, curve *P256V1Curve, pubX *big.Int, pubY *big.Int, userId []byte) []byte {
digest.Reset()
userIdLen := uint16(len(userId) * 8)
var userIdLenBytes [2]byte
binary.BigEndian.PutUint16(userIdLenBytes[:], userIdLen)
digest.Write(userIdLenBytes[:])
if userId != nil && len(userId) > 0 {
digest.Write(userId)
}
digest.Write(bigIntTo32Bytes(curve.A))
digest.Write(bigIntTo32Bytes(curve.B))
digest.Write(bigIntTo32Bytes(curve.Gx))
digest.Write(bigIntTo32Bytes(curve.Gy))
digest.Write(bigIntTo32Bytes(pubX))
digest.Write(bigIntTo32Bytes(pubY))
return digest.Sum(nil)
}
func calculateE(digest hash.Hash, curve *P256V1Curve, pubX *big.Int, pubY *big.Int, userId []byte, src []byte) *big.Int {
z := getZ(digest, curve, pubX, pubY, userId)
digest.Reset()
digest.Write(z)
digest.Write(src)
eHash := digest.Sum(nil)
return new(big.Int).SetBytes(eHash)
}
func MarshalSign(r, s *big.Int) ([]byte, error) {
result, err := asn1.Marshal(sm2Signature{r, s})
if err != nil {
return nil, err
}
return result, nil
}
func UnmarshalSign(sign []byte) (r, s *big.Int, err error) {
sm2Sign := new(sm2Signature)
_, err = asn1.Unmarshal(sign, sm2Sign)
if err != nil {
return nil, nil, err
}
return sm2Sign.R, sm2Sign.S, nil
}
func SignToRS(priv *PrivateKey, userId []byte, in []byte) (r, s *big.Int, err error) {
digest := sm3.New()
pubX, pubY := priv.Curve.ScalarBaseMult(priv.D.Bytes())
if userId == nil {
userId = sm2SignDefaultUserId
}
e := calculateE(digest, &priv.Curve, pubX, pubY, userId, in)
intZero := new(big.Int).SetInt64(0)
intOne := new(big.Int).SetInt64(1)
for {
var k *big.Int
var err error
for {
k, err = nextK(rand.Reader, priv.Curve.N)
if err != nil {
return nil, nil, err
}
px, _ := priv.Curve.ScalarBaseMult(k.Bytes())
r = util.Add(e, px)
r = util.Mod(r, priv.Curve.N)
rk := new(big.Int).Set(r)
rk = rk.Add(rk, k)
if r.Cmp(intZero) != 0 && rk.Cmp(priv.Curve.N) != 0 {
break
}
}
dPlus1ModN := util.Add(priv.D, intOne)
dPlus1ModN = util.ModInverse(dPlus1ModN, priv.Curve.N)
s = util.Mul(r, priv.D)
s = util.Sub(k, s)
s = util.Mod(s, priv.Curve.N)
s = util.Mul(dPlus1ModN, s)
s = util.Mod(s, priv.Curve.N)
if s.Cmp(intZero) != 0 {
break
}
}
return r, s, nil
}
// 签名结果为DER编码的字节数组
func Sign(priv *PrivateKey, userId []byte, in []byte) ([]byte, error) {
r, s, err := SignToRS(priv, userId, in)
if err != nil {
return nil, err
}
return MarshalSign(r, s)
}
func VerifyByRS(pub *PublicKey, userId []byte, src []byte, r, s *big.Int) bool {
intOne := new(big.Int).SetInt64(1)
if r.Cmp(intOne) == -1 || r.Cmp(pub.Curve.N) >= 0 {
return false
}
if s.Cmp(intOne) == -1 || s.Cmp(pub.Curve.N) >= 0 {
return false
}
digest := sm3.New()
if userId == nil {
userId = sm2SignDefaultUserId
}
e := calculateE(digest, &pub.Curve, pub.X, pub.Y, userId, src)
intZero := new(big.Int).SetInt64(0)
t := util.Add(r, s)
t = util.Mod(t, pub.Curve.N)
if t.Cmp(intZero) == 0 {
return false
}
sgx, sgy := pub.Curve.ScalarBaseMult(s.Bytes())
tpx, tpy := pub.Curve.ScalarMult(pub.X, pub.Y, t.Bytes())
x, y := pub.Curve.Add(sgx, sgy, tpx, tpy)
if util.IsEcPointInfinity(x, y) {
return false
}
expectedR := util.Add(e, x)
expectedR = util.Mod(expectedR, pub.Curve.N)
return expectedR.Cmp(r) == 0
}
// 输入签名须为DER编码的字节数组
func Verify(pub *PublicKey, userId []byte, src []byte, sign []byte) bool {
r, s, err := UnmarshalSign(sign)
if err != nil {
return false
}
return VerifyByRS(pub, userId, src, r, s)
}
Go
1
https://gitee.com/LeeScott/gm.git
git@gitee.com:LeeScott/gm.git
LeeScott
gm
gm
v1.0.0

搜索帮助