1 Star 0 Fork 0

sqos/beats

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
parse.go 13.80 KB
一键复制 编辑 原始数据 按行查看 历史
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
package tls
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/pem"
"fmt"
"strings"
"github.com/elastic/beats/libbeat/common"
"github.com/elastic/beats/libbeat/common/streambuf"
"github.com/elastic/beats/libbeat/logp"
)
// direction records which side of the connection this parser has been
// identified as, based on the handshake messages it has seen (see
// setDirection). dirUnknown is the zero value, used until a
// direction-revealing message arrives.
type direction uint8

const (
	dirUnknown direction = iota
	dirClient
	dirServer
)
const (
	// Largest record payload accepted by recordHeader.isValid:
	// 2^14 bytes plus a 2048-byte allowance.
	maxTLSRecordLength = (1 << 14) + 2048
	// For safety, reject handshake messages longer than 64k (same as stdlib);
	// bufferHandshake returns an error and discards its buffer for them.
	maxHandshakeSize = 1 << 16
	// Record header: type (1 byte) + version (2 bytes) + length (2 bytes).
	recordHeaderSize = 5
	// Handshake message header: type (1 byte) + 24-bit length (3 bytes).
	handshakeHeaderSize = 4
	// Fixed part of a hello message excluding the random bytes:
	// version (2) + timestamp (4) + session ID length (1).
	helloHeaderLength = 7
	// Random bytes that follow the 4-byte timestamp in a hello message.
	randomDataLength = 28
)
// recordType is the content-type byte at offset 0 of a TLS record header.
type recordType uint8

// Record content types dispatched on by parser.parse.
// NOTE(review): only recordTypeChangeCipherSpec carries the recordType type;
// the remaining constants are untyped integer constants. They still compare
// correctly against recordType values, but consider typing them explicitly.
const (
	recordTypeChangeCipherSpec recordType = 20
	recordTypeAlert                       = 21
	recordTypeHandshake                   = 22
	recordTypeApplicationData             = 23
)
// handshakeType is the message-type byte at offset 0 of a handshake message
// header (see readHandshakeHeader).
type handshakeType uint8

// Handshake message types dispatched on by parser.parseHandshake.
// NOTE(review): only helloRequest carries the handshakeType type; the other
// constants are untyped integer constants. Comparisons still work, but
// consider typing them explicitly.
const (
	helloRequest       handshakeType = 0
	clientHello                      = 1
	serverHello                      = 2
	certificate                      = 11
	serverKeyExchange                = 12
	certificateRequest               = 13
	clientKeyExchange                = 16
)
// parserResult is the outcome reported by parser.parse.
type parserResult int8

const (
	// resultOK: every byte in the input buffer was consumed.
	resultOK parserResult = iota
	// resultFailed: an invalid record header or an unparsable record was seen.
	resultFailed
	// resultMore: an incomplete record (or record header) remains buffered.
	resultMore
	// resultEncrypted: a ChangeCipherSpec record was seen; the remaining
	// data was discarded.
	resultEncrypted
)
// tlsTicket holds the "session_ticket" extension value, if any, extracted by
// helloMessage.parseExtensions.
type tlsTicket struct {
	// present is true when a string-valued session_ticket was found.
	present bool
	value   string
}
// parser accumulates the TLS handshake information observed in a stream.
// NOTE(review): presumably one parser per stream direction; setDirection
// warns if the identified side changes.
type parser struct {
	// Buffer to accumulate records until a full handshake message
	// is received
	handshakeBuf streambuf.Buffer

	// Side of the connection identified from handshake messages.
	direction direction

	// Alerts parsed from alert records (appended by parseAlert, defined
	// elsewhere).
	alerts []alert
	// Certificates decoded from Certificate handshake messages.
	certificates []*x509.Certificate
	// Most recently parsed ClientHello/ServerHello.
	hello *helloMessage

	// If this end of the connection (server) asked the other end (client)
	// for a certificate
	certRequested bool

	// If a key-exchange message has been sent. Used to detect session resumption
	keyExchanged bool
}
// tlsVersion is the two-byte protocol version as it appears on the wire
// (major, then minor).
type tlsVersion struct {
	major, minor uint8
}
// recordHeader is the decoded 5-byte TLS record header (see readRecordHeader).
type recordHeader struct {
	recordType recordType
	version    tlsVersion
	// Payload length in bytes, excluding the header itself.
	length uint16
}
// handshakeHeader is the decoded 4-byte handshake message header.
// Field order matters: readHandshakeHeader builds it with a positional literal.
type handshakeHeader struct {
	handshakeType handshakeType
	// Body length decoded from the 24-bit length field.
	length int
}
// helloMessage holds the fields decoded from a ClientHello or ServerHello.
type helloMessage struct {
	version   tlsVersion
	timestamp uint32
	// Hex-encoded session ID (empty when the hello carried none).
	sessionID string
	ticket    tlsTicket

	// Filled by parseClientHello: offers made by the client.
	supported struct {
		cipherSuites []cipherSuite
		compression  []compressionMethod
	}

	// Filled by parseServerHello: choices made by the server.
	selected struct {
		cipherSuite cipherSuite
		compression compressionMethod
	}

	// Decoded extensions, as returned by parseExtensions.
	extensions common.MapStr
}
// readRecordHeader decodes the 5-byte TLS record header at the current read
// position of buf without consuming it. Byte 0 is the content type, bytes
// 1-2 the protocol version and bytes 3-4 the payload length (network order).
// Any read error is returned unchanged with a nil header.
func readRecordHeader(buf *streambuf.Buffer) (*recordHeader, error) {
	contentType, err := buf.ReadNetUint8At(0)
	if err != nil {
		return nil, err
	}
	major, err := buf.ReadNetUint8At(1)
	if err != nil {
		return nil, err
	}
	minor, err := buf.ReadNetUint8At(2)
	if err != nil {
		return nil, err
	}
	payloadLen, err := buf.ReadNetUint16At(3)
	if err != nil {
		return nil, err
	}
	return &recordHeader{
		recordType: recordType(contentType),
		version:    tlsVersion{major: major, minor: minor},
		length:     payloadLen,
	}, nil
}
// readHandshakeHeader decodes the 4-byte handshake message header at the
// current read position of buf without consuming it: a type byte followed
// by a 24-bit length, read as one high byte plus a 16-bit word in network
// byte order. Any read error is returned unchanged with a nil header.
func readHandshakeHeader(buf *streambuf.Buffer) (*handshakeHeader, error) {
	msgType, err := buf.ReadNetUint8At(0)
	if err != nil {
		return nil, err
	}
	high, err := buf.ReadNetUint8At(1)
	if err != nil {
		return nil, err
	}
	low, err := buf.ReadNetUint16At(2)
	if err != nil {
		return nil, err
	}
	return &handshakeHeader{
		handshakeType: handshakeType(msgType),
		length:        int(high)<<16 | int(low),
	}, nil
}
// String renders the header's type, version and length for logging.
func (header *recordHeader) String() string {
	const layout = "recordHeader type[%v] version[%v] length[%d]"
	return fmt.Sprintf(layout, header.recordType, header.version, header.length)
}
// isValid reports whether the header plausibly starts a TLS record: the
// major version must be 3 and the payload must not exceed maxTLSRecordLength.
func (header *recordHeader) isValid() bool {
	if header.version.major != 3 {
		return false
	}
	return header.length <= maxTLSRecordLength
}
// toMap converts the hello message into a common.MapStr.
//
// If the message carries supported cipher suites or compression methods
// (as filled in by parseClientHello), those lists are emitted. Otherwise
// the selected cipher suite and compression method (as filled in by
// parseServerHello) are emitted. "session_id" and "extensions" are included
// only when present.
func (hello helloMessage) toMap() common.MapStr {
	out := common.MapStr{
		"version": fmt.Sprintf("%d.%d", hello.version.major, hello.version.minor),
	}
	if hello.sessionID != "" {
		out["session_id"] = hello.sessionID
	}

	offered := hello.supported
	if len(offered.cipherSuites) == 0 && len(offered.compression) == 0 {
		out["selected_cipher"] = hello.selected.cipherSuite.String()
		out["selected_compression_method"] = hello.selected.compression.String()
	} else {
		suiteNames := make([]string, 0, len(offered.cipherSuites))
		for _, suite := range offered.cipherSuites {
			suiteNames = append(suiteNames, suite.String())
		}
		out["supported_ciphers"] = suiteNames

		methodNames := make([]string, 0, len(offered.compression))
		for _, method := range offered.compression {
			methodNames = append(methodNames, method.String())
		}
		out["supported_compression_methods"] = methodNames
	}

	if hello.extensions != nil {
		out["extensions"] = hello.extensions
	}
	return out
}
// parse consumes complete TLS records from buf and dispatches each one by
// record type. It returns:
//   - resultEncrypted as soon as a ChangeCipherSpec record is seen (the rest
//     of buf is discarded, since what follows is encrypted);
//   - resultFailed on an unreadable or invalid record header, or when a
//     handshake/alert record cannot be parsed;
//   - resultMore when a partial record or header remains in buf;
//   - resultOK when buf has been fully consumed.
func (parser *parser) parse(buf *streambuf.Buffer) parserResult {
	for buf.Avail(recordHeaderSize) {
		header, err := readRecordHeader(buf)
		if err != nil {
			logp.Warn("internal buffer error: %v", err)
			return resultFailed
		}
		if !header.isValid() {
			return resultFailed
		}

		payloadLen := int(header.length)
		recordEnd := recordHeaderSize + payloadLen
		if !buf.Avail(recordEnd) {
			// Wait until the whole record has arrived.
			return resultMore
		}

		switch header.recordType {
		case recordTypeChangeCipherSpec:
			// Single one-byte message; everything after it is encrypted.
			if isDebug {
				debugf("handshake completed")
			}
			buf.Advance(buf.Len())
			return resultEncrypted

		case recordTypeHandshake:
			if isDebug {
				debugf("got handshake record of size %d", header.length)
			}
			if hsErr := parser.bufferHandshake(buf, payloadLen); hsErr != nil {
				logp.Warn("Error parsing handshake message: %v", hsErr)
				return resultFailed
			}

		case recordTypeAlert:
			view := newBufferView(buf, recordHeaderSize, payloadLen)
			if alertErr := parser.parseAlert(view); alertErr != nil {
				logp.Warn("Error parsing alert message: %v", alertErr)
				return resultFailed
			}

		case recordTypeApplicationData:
			// TODO: Request / Response analytics
			if isDebug {
				debugf("ignoring application data length %d", header.length)
			}

		default:
			if isDebug {
				debugf("ignoring record type %d length %d", header.recordType, header.length)
			}
		}

		buf.Advance(recordEnd)
	}

	if buf.Len() != 0 {
		return resultMore
	}
	return resultOK
}
// bufferHandshake appends the payload of the handshake record at the start
// of buf (length bytes after the record header) to parser.handshakeBuf. It
// then parses every complete handshake message now available there.
//
// Incomplete trailing messages stay buffered for the next record. On an
// append or header-read failure, or on a message larger than
// maxHandshakeSize, the handshake buffer is discarded and an error is
// returned. A message that fails to parse is skipped and an error is
// returned.
func (parser *parser) bufferHandshake(buf *streambuf.Buffer, length int) error {
	hsBuf := &parser.handshakeBuf

	// TODO: parse in-place if message in received buffer is complete
	payload := buf.Bytes()[recordHeaderSize : recordHeaderSize+length]
	if err := hsBuf.Append(payload); err != nil {
		logp.Warn("failed appending to buffer: %v", err)
		hsBuf.Init(nil, false)
		return err
	}

	for hsBuf.Avail(handshakeHeaderSize) {
		header, err := readHandshakeHeader(hsBuf)
		if err != nil {
			logp.Warn("read failed: %v", err)
			hsBuf.Init(nil, false)
			return err
		}
		if header.length > maxHandshakeSize {
			hsBuf.Init(nil, false)
			return fmt.Errorf("message too large (%d bytes)", header.length)
		}

		msgEnd := handshakeHeaderSize + header.length
		if msgEnd > hsBuf.Len() {
			// Message body not complete yet; wait for more records.
			break
		}

		parsed := parser.parseHandshake(header.handshakeType,
			bufferView{hsBuf, handshakeHeaderSize, msgEnd})
		hsBuf.Advance(msgEnd)
		if !parsed {
			return fmt.Errorf("bad handshake %+v", header)
		}
	}

	if hsBuf.Len() == 0 {
		hsBuf.Reset()
	}
	return nil
}
// setDirection records which side of the connection this parser observes.
// It warns when a previously identified side is contradicted, but the new
// value always wins.
func (parser *parser) setDirection(dir direction) {
	alreadyKnown := parser.direction != dirUnknown
	if alreadyKnown && parser.direction != dir {
		logp.Warn("client/server identification mismatch")
	}
	parser.direction = dir
}
// parseHandshake processes a single complete handshake message whose body is
// exposed through buffer. The message type also identifies the sending side,
// which is recorded via setDirection.
// It returns false only when a HelloRequest/ClientHello/ServerHello body is
// rejected; other message types (including unknown ones) always return true.
func (parser *parser) parseHandshake(handshakeType handshakeType, buffer bufferView) bool {
	if isDebug {
		debugf("got handshake message %v [%d]", handshakeType, buffer.length())
	}

	switch handshakeType {
	case helloRequest:
		parser.setDirection(dirServer)
		return parseHelloRequest(buffer)

	case clientHello:
		parser.setDirection(dirClient)
		parser.hello = parseClientHello(buffer)
		return parser.hello != nil

	case serverHello:
		parser.setDirection(dirServer)
		parser.hello = parseServerHello(buffer)
		return parser.hello != nil

	case certificate:
		parser.certificates = append(parser.certificates, parseCertificates(buffer)...)

	case certificateRequest:
		parser.setDirection(dirServer)
		parser.certRequested = true

	case clientKeyExchange:
		parser.setDirection(dirClient)
		parser.keyExchanged = true

	case serverKeyExchange:
		parser.setDirection(dirServer)
		parser.keyExchanged = true
	}
	return true
}
// parseHelloRequest accepts a HelloRequest message. A non-empty body is
// logged but still accepted, so this always returns true.
func parseHelloRequest(buffer bufferView) bool {
	if bodyLen := buffer.length(); bodyLen != 0 {
		logp.Warn("non-empty hello request")
	}
	return true
}
// parseCommonHello decodes the prefix shared by ClientHello and ServerHello
// into dest:
//
//	version (2) | timestamp (4) | random (28, skipped) | session ID length (1) | session ID
//
// It returns the view-relative offset just past the session ID and true on
// success. It returns (0, false) after logging a warning if the prefix cannot
// be read, the major version is not 3, or the session ID is longer than
// 32 bytes or truncated.
func parseCommonHello(buffer bufferView, dest *helloMessage) (int, bool) {
	const (
		sessionIDOffset    = helloHeaderLength + randomDataLength
		sessionIDLenOffset = sessionIDOffset - 1
	)

	var idLen uint8
	readOK := buffer.read8(0, &dest.version.major) &&
		buffer.read8(1, &dest.version.minor) &&
		buffer.read32Net(2, &dest.timestamp) &&
		buffer.read8(sessionIDLenOffset, &idLen) // random bytes are skipped
	if !readOK {
		logp.Warn("failed reading hello message")
		return 0, false
	}

	if major, minor := dest.version.major, dest.version.minor; major != 3 {
		logp.Warn("Not a TLS hello (reported version %d.%d)", major, minor)
		return 0, false
	}

	if idLen > 32 {
		logp.Warn("Not a TLS hello (session id length %d out of bounds)", idLen)
		return 0, false
	}

	rawID := buffer.readBytes(sessionIDOffset, int(idLen))
	if len(rawID) != int(idLen) {
		logp.Warn("Not a TLS hello (failed reading session ID)")
		return 0, false
	}
	dest.sessionID = hex.EncodeToString(rawID)

	return sessionIDOffset + int(idLen), true
}
// parseExtensions decodes the extensions block in buffer into
// hello.extensions. When a "session_ticket" entry is present it is also
// copied into hello.ticket. A non-string value is logged as an error and
// ignored.
func (hello *helloMessage) parseExtensions(buffer bufferView) {
	hello.extensions = parseExtensions(buffer)

	raw, err := hello.extensions.GetValue("session_ticket")
	if err != nil {
		// No ticket extension.
		return
	}
	value, isString := raw.(string)
	if !isString {
		logp.Err("tls ticket data type error")
		return
	}
	hello.ticket.present = true
	hello.ticket.value = value
}
// parseClientHello decodes a ClientHello message body exposed through buffer.
// Offsets are relative to the start of the view. After the common hello
// prefix it reads:
//
//	cipher suites length (2) | cipher suites (2 bytes each) |
//	compression methods length (1) | compression methods (1 byte each) |
//	extensions (rest of the message)
//
// It returns nil (after logging a warning) if any part is malformed.
func parseClientHello(buffer bufferView) *helloMessage {
	var result helloMessage
	pos, ok := parseCommonHello(buffer, &result)
	if !ok {
		return nil
	}

	var cipherSuitesLength uint16
	if !buffer.read16Net(pos, &cipherSuitesLength) {
		logp.Warn("failed parsing client hello cipher suite length")
		return nil
	}
	// Each cipher suite is a 2-byte code. An odd byte count would make the
	// last read straddle into the compression-methods length byte.
	if cipherSuitesLength%2 != 0 {
		logp.Warn("failed parsing client hello cipher suites (odd length %d)", cipherSuitesLength)
		return nil
	}
	for base := pos + 2; base < pos+2+int(cipherSuitesLength); base += 2 {
		var cipher uint16
		if !buffer.read16Net(base, &cipher) {
			logp.Warn("failed parsing client hello cipher suite")
			return nil
		}
		result.supported.cipherSuites = append(result.supported.cipherSuites, cipherSuite(cipher))
	}
	pos += 2 + int(cipherSuitesLength)

	var compMethodsLength uint8
	if !buffer.read8(pos, &compMethodsLength) {
		logp.Warn("failed parsing client hello compression methods length")
		return nil
	}
	limit := pos + 1 + int(compMethodsLength)
	for base := pos + 1; base < limit; base++ {
		var method uint8
		if !buffer.read8(base, &method) {
			logp.Warn("failed parsing client hello compression methods")
			return nil
		}
		result.supported.compression = append(result.supported.compression, compressionMethod(method))
	}

	// Extensions run from the end of the compression methods to the end of
	// the message. subview takes a view-relative start plus a length, so the
	// remaining length must be computed from buffer.length() (view-relative).
	// buffer.limit appears to be the view's end offset within the underlying
	// buffer: bufferHandshake builds views as
	// {buf, handshakeHeaderSize, handshakeHeaderSize+length}. Using it here
	// overshot the message by the view's base offset.
	result.parseExtensions(buffer.subview(limit, buffer.length()-limit))
	return &result
}
// parseServerHello decodes a ServerHello message body exposed through buffer.
// Offsets are relative to the start of the view. After the common hello
// prefix it reads:
//
//	selected cipher suite (2) | selected compression method (1) |
//	extensions (rest of the message)
//
// It returns nil (after logging a warning) if the message is malformed.
func parseServerHello(buffer bufferView) *helloMessage {
	var result helloMessage
	pos, ok := parseCommonHello(buffer, &result)
	if !ok {
		return nil
	}

	var cipher uint16
	var compression uint8
	if !buffer.read16Net(pos, &cipher) ||
		!buffer.read8(pos+2, &compression) {
		logp.Warn("failed parsing server hello cipher suite or compression method")
		return nil
	}
	result.selected.cipherSuite = cipherSuite(cipher)
	result.selected.compression = compressionMethod(compression)

	// subview takes a view-relative start plus a length. The remaining length
	// must therefore come from buffer.length(), not buffer.limit, which
	// appears to be the view's end offset within the underlying buffer.
	extStart := pos + 3
	result.parseExtensions(buffer.subview(extStart, buffer.length()-extStart))
	return &result
}
// parseCertificates decodes a Certificate message body. The body is a 24-bit
// total length followed by a list of DER certificates, each prefixed with its
// own 24-bit length.
//
// It returns nil if the total length does not match the view, if any entry
// overruns the list or is truncated, or if any certificate fails
// x509.ParseCertificate.
func parseCertificates(buffer bufferView) []*x509.Certificate {
	var listLen uint32
	if !buffer.read24Net(0, &listLen) {
		return nil
	}
	end := int(listLen) + 3
	if end != buffer.length() {
		return nil
	}

	var certs []*x509.Certificate
	offset := 3
	for offset+3 <= end {
		var entryLen uint32
		if !buffer.read24Net(offset, &entryLen) {
			return nil
		}
		size := int(entryLen)
		if offset+3+size > end {
			return nil
		}

		der := buffer.readBytes(offset+3, size)
		if len(der) != size {
			return nil
		}
		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil
		}
		certs = append(certs, cert)

		offset += 3 + size
	}
	return certs
}
// String renders the version as "major.minor" in decimal, e.g. "3.3".
func (version tlsVersion) String() string {
	return fmt.Sprint(version.major) + "." + fmt.Sprint(version.minor)
}
// certToMap converts an X.509 certificate into a common.MapStr.
//
// Subject alternative names (DNS names, then e-mail addresses, then IP
// addresses) are emitted as "alternative_names" when any exist. When
// includeRaw is set, the PEM-encoded certificate is added as "raw".
func certToMap(cert *x509.Certificate, includeRaw bool) common.MapStr {
	result := common.MapStr{
		"signature_algorithm":  cert.SignatureAlgorithm.String(),
		"public_key_algorithm": toString(cert.PublicKeyAlgorithm),
		"version":              cert.Version,
		"serial_number":        cert.SerialNumber.Text(10),
		"issuer":               toMap(&cert.Issuer),
		"subject":              toMap(&cert.Subject),
		"not_before":           cert.NotBefore,
		"not_after":            cert.NotAfter,
	}

	altNames := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses)+len(cert.EmailAddresses))
	altNames = append(altNames, cert.DNSNames...)
	altNames = append(altNames, cert.EmailAddresses...)
	for _, addr := range cert.IPAddresses {
		altNames = append(altNames, addr.String())
	}
	if len(altNames) != 0 {
		result["alternative_names"] = altNames
	}

	if !includeRaw {
		return result
	}
	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert.Raw,
	})
	result["raw"] = string(encoded)
	return result
}
// toMap converts a distinguished name into a common.MapStr. Multi-valued
// attributes are joined with a single space. Attributes that end up empty
// are omitted.
func toMap(name *pkix.Name) common.MapStr {
	result := common.MapStr{}
	put := func(key string, values ...string) {
		if joined := strings.Join(values, " "); joined != "" {
			result[key] = joined
		}
	}

	put("country", name.Country...)
	put("organization", name.Organization...)
	put("organizational_unit", name.OrganizationalUnit...)
	put("locality", name.Locality...)
	put("province", name.Province...)
	put("postal_code", name.PostalCode...)
	put("serial_number", name.SerialNumber)
	put("common_name", name.CommonName)
	put("street_address", name.StreetAddress...)

	return result
}
// hasInfo reports whether the parser has collected anything worth reporting:
// a hello message, at least one alert, or at least one certificate.
func (parser *parser) hasInfo() bool {
	switch {
	case parser.hello != nil:
		return true
	case len(parser.alerts) != 0:
		return true
	default:
		return len(parser.certificates) != 0
	}
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/sqos/beats.git
git@gitee.com:sqos/beats.git
sqos
beats
beats
v6.1.2

搜索帮助