1 Star 0 Fork 0

dqc/gmsm

Create your Gitee Account
Explore and code with more than 14 million developers,Free private repositories !:)
Sign up
文件
Clone or Download
sm2.go 15.61 KB
Copy Edit Raw Blame History
dqc authored 2022-11-05 23:29 +08:00 . update sm2/sm2.go.
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
package sm2
import (
"bytes"
"crypto"
"crypto/elliptic"
"crypto/rand"
"encoding/asn1"
"encoding/binary"
"errors"
"io"
"math/big"
"gitee.com/dqc_123/gmsm/sm3"
)
// Package-level defaults and ciphertext layout selectors.
var (
	// default_uid is the user identity used when a caller passes an empty uid:
	// the ASCII string "1234567812345678".
	default_uid = []byte("1234567812345678")

	// C1C3C2 selects the 0x04 || C1 || C3 || C2 ciphertext layout.
	C1C3C2 = 0
	// C1C2C3 selects the 0x04 || C1 || C2 || C3 ciphertext layout.
	C1C2C3 = 1
)
type (
	// PublicKey is an SM2 public key: the curve it lives on together with the
	// affine coordinates (X, Y) of the public point.
	PublicKey struct {
		elliptic.Curve
		X *big.Int
		Y *big.Int
	}

	// PrivateKey is an SM2 private key. It embeds the matching PublicKey and
	// adds the secret scalar D.
	PrivateKey struct {
		PublicKey
		D *big.Int
	}
)
type (
	// sm2Signature is the ASN.1 form of a signature: SEQUENCE { r INTEGER, s INTEGER }.
	sm2Signature struct {
		R *big.Int
		S *big.Int
	}

	// sm2Cipher is the ASN.1 form of an SM2 ciphertext:
	// SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, cipherText OCTET STRING }.
	// The field order defines the encoding and must not change.
	sm2Cipher struct {
		XCoordinate *big.Int
		YCoordinate *big.Int
		HASH        []byte
		CipherText  []byte
	}
)
// Public returns a pointer to the public key embedded in priv. Together with
// Sign it makes *PrivateKey a crypto.Signer.
func (priv *PrivateKey) Public() crypto.PublicKey {
	pub := &priv.PublicKey
	return pub
}
var (
	// errZeroParam is returned by Sm2Sign when the curve order N is zero.
	errZeroParam = errors.New("zero parameter")

	// one and two are shared big.Int constants; they must never be mutated.
	one = big.NewInt(1)
	two = big.NewInt(2)
)
// Sign implements crypto.Signer. It signs msg with priv and returns the DER
// encoding of SEQUENCE { r INTEGER, s INTEGER }, i.e.
//
//	30 len(z) 02 len(r) r 02 len(s) s
//
// where z is everything that follows the outer length byte(s).
//
// msg is the message itself, not a digest: Sm2Sign hashes it together with
// Z_A computed for the default uid. The signer options are ignored, and a nil
// random falls back to crypto/rand.
func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
	r, s, err := Sm2Sign(priv, msg, nil, random)
	if err != nil {
		return nil, err
	}
	sig := sm2Signature{R: r, S: s}
	return asn1.Marshal(sig)
}
// Verify reports whether sign is a valid ASN.1 encoded SM2 signature
// (SEQUENCE { r INTEGER, s INTEGER }) of msg under pub, using the default uid.
//
// Input with extra bytes after the signature SEQUENCE is rejected, so that
// appending data to a valid signature does not yield another accepted one.
func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
	var sig sm2Signature
	rest, err := asn1.Unmarshal(sign, &sig)
	if err != nil || len(rest) != 0 {
		return false
	}
	if sig.R == nil || sig.S == nil {
		return false
	}
	return Sm2Verify(pub, msg, default_uid, sig.R, sig.S)
}
// Sm3Digest returns the SM3 digest e = SM3(Z_A || msg) used by SM2 signing,
// where Z_A is computed from pub and uid (the default uid when uid is empty).
//
// The digest is always exactly 32 bytes. A digest that begins with zero bytes
// is left-padded rather than shortened.
func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
	if len(uid) == 0 {
		uid = default_uid
	}
	za, err := ZA(pub, uid)
	if err != nil {
		return nil, err
	}
	e, err := msgHash(za, msg)
	if err != nil {
		return nil, err
	}
	// big.Int.Bytes drops leading zero bytes; restore the fixed 32-byte width.
	eb := e.Bytes()
	digest := make([]byte, 32)
	copy(digest[len(digest)-len(eb):], eb)
	return digest, nil
}
//****************************Encryption algorithm****************************//

// EncryptAsn1 encrypts data to pub and returns the ciphertext in ASN.1 form.
// It is a method wrapper around the package-level EncryptAsn1. A nil random
// falls back to crypto/rand.
func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
	ciphertext, err := EncryptAsn1(pub, data, random)
	return ciphertext, err
}
// DecryptAsn1 decrypts an ASN.1 encoded SM2 ciphertext with priv. It is a
// method wrapper around the package-level DecryptAsn1.
func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
	plaintext, err := DecryptAsn1(priv, data)
	return plaintext, err
}
//**************************Key agreement algorithm**************************//

// KeyExchangeB runs step 2 of SM2 key agreement on behalf of the responder B.
// It returns the shared key k (klen bytes) and the hashes s1 (prefix 0x02) and
// s2 (prefix 0x03).
//
// ida and idb are the identities of the initiator A and the responder B. priB
// is B's static key and pubA is A's static public key. rpri is B's ephemeral
// key and rpubA is A's ephemeral public key.
func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
	const callerIsA = false
	return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, callerIsA)
}
// KeyExchangeA runs the initiator A's part of SM2 key agreement. It returns
// the shared key k (klen bytes) and the hashes s1 (prefix 0x02) and s2
// (prefix 0x03).
//
// ida and idb are the identities of the initiator A and the responder B. priA
// is A's static key and pubB is B's static public key. rpri is A's ephemeral
// key and rpubB is B's ephemeral public key.
func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
	const callerIsA = true
	return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, callerIsA)
}
//****************************************************************************//
// Sm2Sign computes an SM2 signature (r, s) over msg with priv.
//
// The digest e = SM3(Z_A || msg) is derived from priv's public key and uid.
// An empty uid selects the default uid. random supplies the per-signature
// nonce k and may be nil, in which case crypto/rand is used.
//
// The signature is produced as follows:
//
//	(x1, y1) = k*G
//	r = (e + x1) mod n         retry with a new k if r == 0 or r + k == n
//	s = (1 + d)^-1 * (k - r*d) mod n   retry with a new k if s == 0
//
// On error both r and s are nil. If 1 + d has no inverse modulo n, the key is
// invalid and an error is returned instead of panicking.
func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
	digest, err := priv.PublicKey.Sm3Digest(msg, uid)
	if err != nil {
		return nil, nil, err
	}
	e := new(big.Int).SetBytes(digest)
	c := priv.PublicKey.Curve
	N := c.Params().N
	if N.Sign() == 0 {
		return nil, nil, errZeroParam
	}
	// (1 + d)^-1 mod n does not depend on the nonce, so compute it once.
	d1Inv := new(big.Int).ModInverse(new(big.Int).Add(priv.D, one), N)
	if d1Inv == nil {
		return nil, nil, errors.New("SM2: invalid private key")
	}
	for {
		var k *big.Int
		k, err = randFieldElement(c, random)
		if err != nil {
			return nil, nil, err
		}
		r, _ = c.ScalarBaseMult(k.Bytes())
		r.Add(r, e)
		r.Mod(r, N)
		if r.Sign() == 0 || new(big.Int).Add(r, k).Cmp(N) == 0 {
			continue
		}
		s = new(big.Int).Mul(priv.D, r)
		s.Sub(k, s)
		s.Mul(s, d1Inv)
		s.Mod(s, N) // Euclidean modulus: the result is non-negative
		if s.Sign() != 0 {
			return r, s, nil
		}
	}
}
// Sm2Verify reports whether (r, s) is a valid SM2 signature of msg under pub.
// An empty uid selects the default uid. A nil or out-of-range r or s is
// rejected; a nil value no longer causes a panic.
//
// Verification computes e = SM3(Z_A || msg) and t = (r + s) mod n. The
// signature is valid iff (e + x1) mod n == r, where (x1, y1) = s*G + t*P.
func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
	if r == nil || s == nil {
		return false
	}
	c := pub.Curve
	N := c.Params().N
	// r and s must both lie in [1, n-1].
	if r.Sign() <= 0 || s.Sign() <= 0 || r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
		return false
	}
	if len(uid) == 0 {
		uid = default_uid
	}
	za, err := ZA(pub, uid)
	if err != nil {
		return false
	}
	e, err := msgHash(za, msg)
	if err != nil {
		return false
	}
	t := new(big.Int).Add(r, s)
	t.Mod(t, N)
	if t.Sign() == 0 {
		return false
	}
	x1, y1 := c.ScalarBaseMult(s.Bytes())
	x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
	x, _ := c.Add(x1, y1, x2, y2)
	x.Add(x, e)
	x.Mod(x, N)
	return x.Cmp(r) == 0
}
// Verify reports whether (r, s) is a valid SM2 signature for a precomputed
// digest hash under pub. hash is the SM3 digest e = SM3(Z_A || msg); a caller
// holding only the message can compute it like this:
//
//	za, err := ZA(pub, uid)
//	if err != nil {
//		return
//	}
//	e, err := msgHash(za, msg)
//	hash := e.Bytes()
//
// A nil or out-of-range r or s is rejected; a nil value no longer causes a
// panic.
func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
	if r == nil || s == nil {
		return false
	}
	c := pub.Curve
	N := c.Params().N
	if r.Sign() <= 0 || s.Sign() <= 0 {
		return false
	}
	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
		return false
	}
	// SM2 differs from ECDSA here: the second scalar is t = (r + s) mod n.
	t := new(big.Int).Add(r, s)
	t.Mod(t, N)
	if t.Sign() == 0 {
		return false
	}
	// (x1, y1) = s*G + t*P; the signature is valid iff (e + x1) mod n == r.
	x1, y1 := c.ScalarBaseMult(s.Bytes())
	x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
	x, _ := c.Add(x1, y1, x2, y2)
	e := new(big.Int).SetBytes(hash)
	x.Add(x, e)
	x.Mod(x, N)
	return x.Cmp(r) == 0
}
// Encrypt encrypts data to pub using SM2 public-key encryption.
//
// The result is 0x04 || C1 || C3 || C2 for mode C1C3C2 (and for unrecognized
// modes), or 0x04 || C1 || C2 || C3 for mode C1C2C3, where
//
//	C1 = x1 || y1               (x1, y1) = k*G, each coordinate padded to 32 bytes
//	C3 = SM3(x2 || data || y2)  (x2, y2) = k*P
//	C2 = data XOR KDF(x2 || y2, len(data))
//
// k is drawn from random, which may be nil to use crypto/rand. If the KDF
// output is all zero, a fresh k is drawn.
//
// data must not be empty. An empty plaintext makes the KDF output empty, which
// is always reported as all-zero, so the retry loop would otherwise never
// terminate.
func Encrypt(pub *PublicKey, data []byte, random io.Reader, mode int) ([]byte, error) {
	if len(data) == 0 {
		return nil, errors.New("Encrypt: plaintext is empty")
	}
	// pad left-pads a big-endian coordinate to the 32-byte field size.
	pad := func(b []byte) []byte {
		if n := len(b); n < 32 {
			return append(zeroByteSlice()[:32-n], b...)
		}
		return b
	}
	curve := pub.Curve
	for {
		k, err := randFieldElement(curve, random)
		if err != nil {
			return nil, err
		}
		x1, y1 := curve.ScalarBaseMult(k.Bytes())
		x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
		x2Buf, y2Buf := pad(x2.Bytes()), pad(y2.Bytes())

		// C2 = M xor KDF(x2 || y2); an all-zero key stream requires a fresh k.
		c2, ok := kdf(len(data), x2Buf, y2Buf)
		if !ok {
			continue
		}
		for i := range c2 {
			c2[i] ^= data[i]
		}

		// C3 = SM3(x2 || M || y2)
		tm := make([]byte, 0, len(x2Buf)+len(data)+len(y2Buf))
		tm = append(tm, x2Buf...)
		tm = append(tm, data...)
		tm = append(tm, y2Buf...)
		c3 := sm3.Sm3Sum(tm)

		// 0x04 || C1, followed by C2 and C3 in the requested order.
		out := make([]byte, 0, 1+64+len(c3)+len(c2))
		out = append(out, 0x04)
		out = append(out, pad(x1.Bytes())...)
		out = append(out, pad(y1.Bytes())...)
		switch mode {
		case C1C3C2:
			out = append(out, c3...)
			out = append(out, c2...)
		case C1C2C3:
			out = append(out, c2...)
			out = append(out, c3...)
		default:
			out = append(out, c3...)
			out = append(out, c2...)
		}
		return out, nil
	}
}
// Decrypt decrypts an SM2 ciphertext with priv.
//
// data is a one-byte prefix followed by C1, C3 and C2:
//
//   - The prefix is 0x04 as written by Encrypt; it is not checked.
//   - C1 is x1 || y1 (64 bytes).
//   - C3 is the 32-byte SM3 hash.
//   - C2 is the encrypted message.
//
// The order is C1 || C3 || C2 for mode C1C3C2 and for unrecognized modes, or
// C1 || C2 || C3 for mode C1C2C3.
//
// Decryption fails in any of these cases:
//
//   - data is too short to hold a non-empty C2.
//   - C1 is not a point on priv's curve.
//   - The KDF output is all zero.
//   - C3 does not equal SM3(x2 || M || y2).
//
// On failure the plaintext is nil, so unauthenticated data is never returned.
func Decrypt(priv *PrivateKey, data []byte, mode int) ([]byte, error) {
	errDecrypt := errors.New("Decrypt: failed to decrypt")
	// prefix (1) + C1 (64) + C3 (32) + at least one byte of C2.
	if len(data) < 1+64+32+1 {
		return nil, errDecrypt
	}
	body := data[1:]
	c1 := body[:64]
	var c2, c3 []byte
	switch mode {
	case C1C3C2:
		c3, c2 = body[64:96], body[96:]
	case C1C2C3:
		c2, c3 = body[64:len(body)-32], body[len(body)-32:]
	default:
		c3, c2 = body[64:96], body[96:]
	}

	curve := priv.Curve
	x1 := new(big.Int).SetBytes(c1[:32])
	y1 := new(big.Int).SetBytes(c1[32:])
	// Validate C1 before multiplying it by the private scalar. Otherwise an
	// off-curve point could be used to probe priv.D (invalid-curve attack).
	if !curve.IsOnCurve(x1, y1) {
		return nil, errDecrypt
	}
	x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
	pad := func(b []byte) []byte {
		if n := len(b); n < 32 {
			return append(zeroByteSlice()[:32-n], b...)
		}
		return b
	}
	x2Buf, y2Buf := pad(x2.Bytes()), pad(y2.Bytes())

	msg, ok := kdf(len(c2), x2Buf, y2Buf)
	if !ok {
		return nil, errDecrypt
	}
	for i := range msg {
		msg[i] ^= c2[i]
	}

	tm := make([]byte, 0, len(x2Buf)+len(msg)+len(y2Buf))
	tm = append(tm, x2Buf...)
	tm = append(tm, msg...)
	tm = append(tm, y2Buf...)
	if !bytes.Equal(sm3.Sm3Sum(tm), c3) {
		return nil, errDecrypt
	}
	return msg, nil
}
// keyExchange implements the computation shared by step 2 (responder B) and
// step 3 (initiator A) of SM2 key agreement. Both parties call it to derive
// the same byte strings.
//
//	klen     length in bytes of the derived key k
//	ida, idb identities of the initiator A and the responder B
//	pri      the caller's static private key
//	pub      the peer's static public key
//	rpri     the caller's ephemeral SM2 key
//	rpub     the peer's ephemeral SM2 public key
//	thisISA  true when called by A (step 3), false when called by B
//
// It returns k (klen bytes) and the hashes s1 (prefix 0x02) and s2 (prefix
// 0x03). Both hashes cover y_V and the hash of x_V, Z_A, Z_B and both
// ephemeral points.
//
// Coordinates are encoded as fixed 32-byte big-endian strings before they are
// fed to the KDF and SM3, so values with leading zero bytes are not shortened.
// Errors from ZA and a point at infinity for V are now returned instead of
// being silently dropped.
func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
	curve := P256Sm2()
	N := curve.Params().N
	if !curve.IsOnCurve(rpub.X, rpub.Y) {
		return nil, nil, nil, errors.New("Ra not on curve")
	}

	// t = (d_own + xhat_own * r_own) mod n
	t := new(big.Int).Mul(keXHat(rpri.PublicKey.X), rpri.D)
	t.Add(t, pri.D)
	t.Mod(t, N)

	// V = t * (P_peer + xhat_peer * R_peer)
	peerHat := keXHat(rpub.X)
	rx, ry := curve.ScalarMult(rpub.X, rpub.Y, peerHat.Bytes())
	sx, sy := curve.Add(pub.X, pub.Y, rx, ry)
	vx, vy := curve.ScalarMult(sx, sy, t.Bytes())
	if vx.Sign() == 0 || vy.Sign() == 0 {
		return nil, nil, nil, errors.New("V is infinite")
	}

	// Z_A always belongs to the initiator and Z_B to the responder.
	pza, pzb := pub, &pri.PublicKey
	if thisISA {
		pza, pzb = &pri.PublicKey, pub
	}
	za, err := ZA(pza, ida)
	if err != nil {
		return nil, nil, nil, err
	}
	zb, err := ZA(pzb, idb)
	if err != nil {
		return nil, nil, nil, err
	}

	vxBuf, vyBuf := keFieldBytes(vx), keFieldBytes(vy)
	k, ok := kdf(klen, vxBuf, vyBuf, za, zb)
	if !ok {
		return nil, nil, nil, errors.New("kdf: zero key")
	}

	// Both branches put the responder's ephemeral point R_B before the
	// initiator's R_A, so A and B agree with each other.
	// NOTE(review): GB/T 32918.3 appears to order the points as
	// x1 || y1 || x2 || y2 with R_A = (x1, y1) first. Confirm this before
	// relying on s1/s2 interoperating with other SM2 implementations.
	own := BytesCombine(keFieldBytes(rpri.X), keFieldBytes(rpri.Y))
	peer := BytesCombine(keFieldBytes(rpub.X), keFieldBytes(rpub.Y))
	h1 := BytesCombine(vxBuf, za, zb, peer, own)
	if !thisISA {
		h1 = BytesCombine(vxBuf, za, zb, own, peer)
	}
	hash := sm3.Sm3Sum(h1)
	s1 = sm3.Sm3Sum(BytesCombine([]byte{0x02}, vyBuf, hash))
	s2 = sm3.Sm3Sum(BytesCombine([]byte{0x03}, vyBuf, hash))
	return k, s1, s2, nil
}

// keFieldBytes encodes a non-negative coordinate as a 32-byte big-endian
// string, left-padding with zeros. Longer values are returned unpadded.
func keFieldBytes(x *big.Int) []byte {
	b := x.Bytes()
	if len(b) >= 32 {
		return b
	}
	out := make([]byte, 32)
	copy(out[32-len(b):], b)
	return out
}
// msgHash returns e = SM3(za || msg) as a non-negative big.Int. The error
// result is always nil.
func msgHash(za, msg []byte) (*big.Int, error) {
	h := sm3.New()
	for _, part := range [][]byte{za, msg} {
		h.Write(part)
	}
	sum := h.Sum(nil)
	return new(big.Int).SetBytes(sum[:32]), nil
}
// ZA computes the user digest Z_A used by SM2 signing and key agreement:
//
//	Z_A = SM3(ENTL_A || ID_A || a || b || x_G || y_G || x_A || y_A)
//
// ENTL_A is the bit length of uid encoded as a 2-byte big-endian integer, so
// uid must be shorter than 8192 bytes. The public-key coordinates are
// left-padded to 32 bytes. The curve constants are written as returned by
// big.Int.Bytes.
func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
	h := sm3.New()
	n := len(uid)
	if n >= 8192 {
		return []byte{}, errors.New("SM2: uid too large")
	}
	entl := uint16(8 * n)
	h.Write([]byte{byte(entl >> 8), byte(entl)})
	if n > 0 {
		h.Write(uid)
	}
	// Curve coefficients a, b and the base point G.
	h.Write(sm2P256ToBig(&sm2P256.a).Bytes())
	h.Write(sm2P256.B.Bytes())
	h.Write(sm2P256.Gx.Bytes())
	h.Write(sm2P256.Gy.Bytes())
	// Public key P_A = (x_A, y_A), each coordinate fixed at 32 bytes.
	for _, coord := range []*big.Int{pub.X, pub.Y} {
		buf := coord.Bytes()
		if short := 32 - len(buf); short > 0 {
			buf = append(zeroByteSlice()[:short], buf...)
		}
		h.Write(buf)
	}
	return h.Sum(nil)[:32], nil
}
// zeroByteSlice returns a freshly allocated 32-byte slice of zeros. Callers
// slice it to the padding length they need, e.g. zeroByteSlice()[:32-n].
func zeroByteSlice() []byte {
	return make([]byte, 32)
}
// EncryptAsn1 encrypts data to pub and returns the ciphertext ASN.1 encoded as
// SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, cipherText OCTET STRING }.
// rand supplies the encryption randomness; nil selects crypto/rand.
func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
	raw, err := Encrypt(pub, data, rand, C1C3C2)
	if err == nil {
		return CipherMarshal(raw)
	}
	return nil, err
}
// DecryptAsn1 decrypts an ASN.1 encoded SM2 ciphertext, as produced by
// EncryptAsn1, with the private key pub.
func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
	raw, err := CipherUnmarshal(data)
	if err == nil {
		return Decrypt(pub, raw, C1C3C2)
	}
	return nil, err
}
// CipherMarshal converts a raw C1C3C2 ciphertext into its ASN.1 encoding
//
//	SEQUENCE { x INTEGER, y INTEGER, hash OCTET STRING, cipherText OCTET STRING }
//
// The raw input is 0x04 || x || y || hash || cipherText, where x, y and hash
// are 32 bytes each. The leading format byte is skipped without being checked.
// Input shorter than 97 bytes is rejected with an error instead of causing a
// panic.
func CipherMarshal(data []byte) ([]byte, error) {
	if len(data) < 1+96 {
		return nil, errors.New("CipherMarshal: ciphertext too short")
	}
	body := data[1:]
	x := new(big.Int).SetBytes(body[:32])
	y := new(big.Int).SetBytes(body[32:64])
	return asn1.Marshal(sm2Cipher{x, y, body[64:96], body[96:]})
}
// CipherUnmarshal converts an ASN.1 encoded SM2 ciphertext (see CipherMarshal)
// back into the raw C1C3C2 form 0x04 || x || y || hash || cipherText, with x
// and y left-padded to 32 bytes.
//
// The following are rejected with an error:
//
//   - negative coordinates
//   - coordinates longer than 32 bytes
//   - a hash that is not exactly 32 bytes
//
// Any of these would shift the fixed 64/96-byte offsets that Decrypt uses to
// split the raw form.
func CipherUnmarshal(data []byte) ([]byte, error) {
	var cipher sm2Cipher
	if _, err := asn1.Unmarshal(data, &cipher); err != nil {
		return nil, err
	}
	if cipher.XCoordinate.Sign() < 0 || cipher.YCoordinate.Sign() < 0 {
		return nil, errors.New("CipherUnmarshal: negative coordinate")
	}
	x := cipher.XCoordinate.Bytes()
	y := cipher.YCoordinate.Bytes()
	if len(x) > 32 || len(y) > 32 {
		return nil, errors.New("CipherUnmarshal: coordinate too large")
	}
	if len(cipher.HASH) != 32 {
		return nil, errors.New("CipherUnmarshal: invalid hash length")
	}
	out := make([]byte, 1+64, 1+64+32+len(cipher.CipherText))
	out[0] = 0x04
	copy(out[1+32-len(x):1+32], x) // x component, right-aligned in bytes 1..32
	copy(out[1+64-len(y):1+64], y) // y component, right-aligned in bytes 33..64
	out = append(out, cipher.HASH...)
	out = append(out, cipher.CipherText...)
	return out, nil
}
// keXHat computes xhat = 2^w + (x & (2^w - 1)) with w fixed at 127. This is
// the truncated coordinate used by SM2 key agreement. Only |x| is used, and x
// itself is not modified.
func keXHat(x *big.Int) (xul *big.Int) {
	const w = 127
	twoW := new(big.Int).Lsh(big.NewInt(1), w)
	mask := new(big.Int).Sub(twoW, big.NewInt(1))
	low := new(big.Int).Abs(x)
	low.And(low, mask)
	return low.Add(low, twoW)
}
// BytesCombine returns the concatenation of pBytes. The result never aliases
// the inputs. With no arguments it returns an empty, non-nil slice.
func BytesCombine(pBytes ...[]byte) []byte {
	// bytes.Join with an empty separator is plain concatenation. The previous
	// per-element copy into a second [][]byte (and the local variable that
	// shadowed the builtin len) was unnecessary.
	return bytes.Join(pBytes, nil)
}
// intToBytes encodes x, truncated to 32 bits, as a 4-byte big-endian string.
// kdf uses it for its block counter.
func intToBytes(x int) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], uint32(x))
	return out[:]
}
// kdf is the SM2 key derivation function. It returns the first length bytes of
//
//	SM3(Z || ct_1) || SM3(Z || ct_2) || ...
//
// where Z is the concatenation of the x arguments and ct_i is a 32-bit
// big-endian counter starting at 1.
//
// The boolean is false when the output is all zero or empty (length <= 0).
// Callers in this file treat that as a failed derivation.
func kdf(length int, x ...[]byte) ([]byte, bool) {
	var out []byte
	h := sm3.New()
	blocks := (length + 31) / 32
	for ct := 1; ct <= blocks; ct++ {
		h.Reset()
		for _, part := range x {
			h.Write(part)
		}
		h.Write(intToBytes(ct))
		block := h.Sum(nil)
		if ct == blocks && length%32 != 0 {
			block = block[:length%32] // final partial block
		}
		out = append(out, block...)
	}
	for _, b := range out {
		if b != 0 {
			return out, true
		}
	}
	return out, false
}
// randFieldElement returns a scalar k in [1, n-1], where n is the order of c.
// It reads BitSize/8+8 bytes from random, using crypto/rand when random is
// nil. The 64 extra bits keep the bias of the modular reduction small.
func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
	if random == nil {
		random = rand.Reader // no caller-supplied source: use the system CSPRNG
	}
	params := c.Params()
	buf := make([]byte, params.BitSize/8+8)
	if _, err = io.ReadFull(random, buf); err != nil {
		return nil, err
	}
	nMinus1 := new(big.Int).Sub(params.N, one)
	k = new(big.Int).SetBytes(buf)
	k.Mod(k, nMinus1)
	return k.Add(k, one), nil
}
// GenerateKey generates an SM2 key pair on the curve returned by P256Sm2.
//
// The private scalar d = (seed mod (n-2)) + 1 lies in [1, n-2]. seed is
// BitSize/8+8 bytes read from random, which may be nil to use crypto/rand.
// The public point is d*G.
func GenerateKey(random io.Reader) (*PrivateKey, error) {
	curve := P256Sm2()
	if random == nil {
		random = rand.Reader // no caller-supplied source: use the system CSPRNG
	}
	params := curve.Params()
	seed := make([]byte, params.BitSize/8+8)
	if _, err := io.ReadFull(random, seed); err != nil {
		return nil, err
	}
	d := new(big.Int).SetBytes(seed)
	d.Mod(d, new(big.Int).Sub(params.N, two))
	d.Add(d, one)

	x, y := curve.ScalarBaseMult(d.Bytes())
	return &PrivateKey{
		PublicKey: PublicKey{Curve: curve, X: x, Y: y},
		D:         d,
	}, nil
}
// zr is an io.Reader that yields an endless stream of zero bytes. The
// embedded io.Reader is never consulted because Read is defined on *zr.
type zr struct {
	io.Reader
}

// Read zero-fills dst and always reports len(dst) bytes read with no error.
func (z *zr) Read(dst []byte) (n int, err error) {
	n = len(dst)
	for i := range dst {
		dst[i] = 0
	}
	return n, nil
}

// zeroReader is a shared zero-byte reader.
var zeroReader = &zr{}
// getLastBit returns the lowest bit of a: 1 when a is odd, 0 when it is even.
func getLastBit(a *big.Int) uint {
	const lsb = 0
	return a.Bit(lsb)
}
// Decrypt implements crypto.Decrypter. It decrypts a raw SM2 ciphertext in the
// C1C3C2 layout with priv. The rand and opts arguments are ignored.
func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
	plaintext, err = Decrypt(priv, msg, C1C3C2)
	return plaintext, err
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dqc_123/gmsm.git
git@gitee.com:dqc_123/gmsm.git
dqc_123
gmsm
gmsm
3b37d29bc263

Search