1 Star 0 Fork 0

h79/goutils

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
scp.go 7.41 KB
一键复制 编辑 原始数据 按行查看 历史
huqiuyun 提交于 2022-10-24 11:57 . rpc ,close connect,git
package ssh
import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"gitee.com/h79/goutils/common"
	"gitee.com/h79/goutils/common/file"
	"golang.org/x/crypto/ssh"
)
// KRemoteCmd is the remote program that send and recv fall back to when
// Option.Cmd is empty.
const (
	KRemoteCmd = "scp"
)
// Scp performs scp transfers over a wrapped Session.
type Scp struct {
	session *Session // connection every transfer is run on
	cmd     string   // remote command line; set by WithCmd and rebuilt by send/recv
}
// Option describes a single scp transfer.
type Option struct {
	// Cmd is the remote scp program; send/recv substitute KRemoteCmd when empty.
	Cmd string
	// Local is the path on this machine.
	Local Path
	// Remote is the path on the remote host.
	Remote Path
	// Permission is the mode placed after "C" in the upload header (e.g. "0644").
	Permission string
	// UpdatePermission adds -p: preserve the original file's modification
	// time, access time and access permissions.
	UpdatePermission bool
	// Recursive adds -r.
	Recursive bool
}
// NewScp returns an Scp that runs its transfers over session.
func NewScp(session *Session) *Scp {
	s := new(Scp)
	s.session = session
	return s
}
// WithCmd stores cmd as the remote command line and returns scp for chaining.
// Note that SendToByReader and ReceiveFromByWriter rebuild the command from
// their Option, replacing this value.
func (scp *Scp) WithCmd(cmd string) *Scp {
	scp.cmd = cmd
	return scp
}
// connect closes whatever the wrapped Session currently holds, reconnects it
// and hands back the underlying *ssh.Session.
func (scp *Scp) connect() (*ssh.Session, error) {
	scp.close()
	err := scp.session.Connect()
	if err != nil {
		return nil, err
	}
	return scp.session.Session, nil
}
// Close closes the wrapped Session.
func (scp *Scp) Close() {
	scp.session.Close()
}
// close releases the wrapped Session; shared by Close and connect.
func (scp *Scp) close() {
	s := scp.session
	s.Close()
}
// SendTo uploads opt.Local to opt.Remote and returns the outcome as a Result.
//
// The local path is classified with file.IsDir:
//   - 0: a single file, uploaded with sendFile.
//   - 1: a directory. If opt.Remote is also a directory the request goes to
//     sendDir; otherwise localFile appends the remote base name to the local
//     directory and that single file is uploaded.
//
// Any other classification (presumably a missing or unreadable path — TODO
// confirm against file.IsDir) yields a Result whose Error describes the
// problem, instead of a nil *Result the caller cannot inspect.
func (scp *Scp) SendTo(id int, opt *Option) *Result {
	switch file.IsDir(opt.Local.Name) {
	case 0:
		opt.Local.IsDir = false
		return scp.sendFile(id, opt)
	case 1:
		opt.Local.IsDir = true
		if opt.Remote.IsDir {
			return scp.sendDir(id, opt)
		}
		// Remote names a file: send the file of that name from inside the
		// local directory.
		localFile(opt)
		return scp.sendFile(id, opt)
	}
	now := time.Now()
	return &Result{
		Id:         id,
		Host:       scp.session.Host,
		LocalPath:  opt.Local.Name,
		RemotePath: opt.Remote.Name,
		StartTime:  now,
		EndTime:    now,
		Error:      fmt.Errorf("ssh: local path %q is neither a file nor a directory", opt.Local.Name),
	}
}
// sendFile uploads the single local file opt.Local.Name and records the
// timing and outcome in a Result.
//
// RemotePath is captured before remoteFile may extend opt.Remote.Name with
// the local base name. remoteFile is only consulted once the local file has
// been opened successfully.
func (scp *Scp) sendFile(id int, opt *Option) *Result {
	res := &Result{
		Id:         id,
		Host:       scp.session.Host,
		LocalPath:  opt.Local.Name,
		RemotePath: opt.Remote.Name,
		StartTime:  time.Now(),
	}
	defer func() { res.EndTime = time.Now() }()

	reader, length, openErr := file.Open(opt.Local.Name)
	if openErr != nil {
		res.Error = openErr
		return res
	}
	defer reader.Close()

	target := remoteFile(opt)
	res.Error = scp.SendToByReader(reader, length, opt, target)
	return res
}
// sendDir is the directory-to-directory branch of SendTo.
//
// Recursive directory upload is not implemented yet. Instead of returning a
// nil *Result, which a caller reading Result.Error would dereference, it
// reports the limitation through Result.Error.
func (scp *Scp) sendDir(id int, opt *Option) *Result {
	now := time.Now()
	return &Result{
		Id:         id,
		Host:       scp.session.Host,
		LocalPath:  opt.Local.Name,
		RemotePath: opt.Remote.Name,
		StartTime:  now,
		EndTime:    now,
		Error:      errors.New("ssh: uploading a directory is not implemented"),
	}
}
// SendToByReader uploads size bytes read from src to the remote host under
// the name remoteFilename.
//
// It builds the remote command from opt (see send), opens a fresh session
// and runs two goroutines:
//   - one executes the remote command with session.Run;
//   - the other writes the "C<perm> <size> <name>" header, checks the reply,
//     streams exactly size bytes, writes a terminating NUL, checks the final
//     reply and then closes stdin.
//
// Each goroutine records its own error, so there is no data race on a shared
// variable. The writer's error wins because checkResponse turns a remote
// warning/error reply into that error; otherwise the command's error is
// returned. An empty opt.Permission falls back to "0644" so the header stays
// well formed.
func (scp *Scp) SendToByReader(src io.Reader, size int64, opt *Option, remoteFilename string) error {
	scp.send(opt)
	perm := opt.Permission
	if perm == "" {
		perm = "0644"
	}
	return scp.handler(func(session *ssh.Session, out io.Reader, w io.WriteCloser) error {
		var (
			wg       sync.WaitGroup
			writeErr error // written only by the writer goroutine
			runErr   error // written only by the command goroutine
		)
		wg.Add(2)
		go func() {
			defer common.Recover()
			defer wg.Done()
			defer w.Close()
			writeErr = func() error {
				if _, err := fmt.Fprintln(w, "C"+perm, size, remoteFilename); err != nil {
					return err
				}
				if err := checkResponse(out); err != nil {
					return err
				}
				// Send exactly the byte count announced in the header; a
				// short src is reported instead of desynchronising the stream.
				if _, err := io.CopyN(w, src, size); err != nil {
					return err
				}
				if _, err := fmt.Fprint(w, "\x00"); err != nil {
					return err
				}
				return checkResponse(out)
			}()
		}()
		go func() {
			defer common.Recover()
			defer wg.Done()
			runErr = session.Run(scp.cmd)
		}()
		// wg.Wait orders both goroutines' writes before the reads below.
		wg.Wait()
		if writeErr != nil {
			return writeErr
		}
		return runErr
	})
}
// ReceiveFrom downloads opt.Remote.Name into a local file.
//
// The local path comes from localFile: when opt.Local is a directory the
// remote base name is appended to it. The file is created with
// opt.Local.Mode if missing and truncated if it exists, so a shorter download
// cannot leave stale trailing bytes from the previous contents. If the
// transfer succeeds but closing the file fails, that close error is returned.
func (scp *Scp) ReceiveFrom(opt *Option) (err error) {
	dst, err := os.OpenFile(localFile(opt), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, opt.Local.Mode)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	return scp.ReceiveFromByWriter(dst, opt)
}
// ReceiveFromByWriter downloads the single remote file opt.Remote.Name and
// writes its contents to dst.
//
// It builds the remote "-f" command from opt (see recv), starts it on a
// fresh session and drives the receiving side of the exchange:
//  1. ack to start;
//  2. read the header;
//  3. ack;
//  4. copy Size bytes into dst;
//  5. ack;
//  6. wait for the command to exit.
//
// Every failure is returned to the caller. This includes a warning/error
// reply from the remote, a malformed header and the command's exit error.
// The original silently returned nil in all of these cases.
func (scp *Scp) ReceiveFromByWriter(dst io.Writer, opt *Option) error {
	scp.recv(opt)
	return scp.handler(func(session *ssh.Session, out io.Reader, in io.WriteCloser) error {
		defer common.Recover()
		defer in.Close()

		// ack writes the single NUL byte that lets the remote continue.
		ack := func() error {
			msg := []byte{0}
			n, err := in.Write(msg)
			if err != nil {
				return err
			}
			if n < len(msg) {
				return errors.New("failed to write ack buffer")
			}
			return nil
		}

		if err := session.Start(scp.cmd); err != nil {
			return err
		}
		if err := ack(); err != nil {
			return err
		}
		res, err := ParseResponse(out)
		if err != nil {
			return err
		}
		// Per the scp protocol, a remote started with -p (opt.UpdatePermission)
		// sends a "T" timestamp line before the "C" header; acknowledge and skip it.
		if res.Type == 'T' {
			if err = ack(); err != nil {
				return err
			}
			if res, err = ParseResponse(out); err != nil {
				return err
			}
		}
		if res.IsFailure() {
			return errors.New(res.GetMessage())
		}
		if res.Type != 'C' {
			return fmt.Errorf("ssh: unexpected scp message type %q", res.Type)
		}
		infos, err := res.ParseFileInfos()
		if err != nil {
			return err
		}
		if err = ack(); err != nil {
			return err
		}
		if _, err = file.CopyN(dst, out, infos.Size); err != nil {
			return err
		}
		if err = ack(); err != nil {
			return err
		}
		return session.Wait()
	})
}
// handler opens a fresh session via connect, attaches its stdout and stdin
// pipes, and passes all three to start, returning whatever start returns.
func (scp *Scp) handler(start func(session *ssh.Session, out io.Reader, in io.WriteCloser) error) error {
	sess, err := scp.connect()
	if err != nil {
		return err
	}
	var (
		stdout io.Reader
		stdin  io.WriteCloser
	)
	if stdout, err = sess.StdoutPipe(); err != nil {
		return err
	}
	if stdin, err = sess.StdinPipe(); err != nil {
		return err
	}
	return start(sess, stdout, stdin)
}
// recv builds the download command line and stores it in scp.cmd.
//
// opt.Cmd defaults to KRemoteCmd; this default is written back into opt.
// The flags are "-f" followed by p, r and d when UpdatePermission, Recursive
// and Remote.IsDir are set, in that order. The result has the form
// "<cmd> <flags> <remote path>".
func (scp *Scp) recv(opt *Option) {
	if len(opt.Cmd) == 0 {
		opt.Cmd = KRemoteCmd
	}
	flags := "-f"
	if opt.UpdatePermission {
		flags += "p"
	}
	if opt.Recursive {
		flags += "r"
	}
	if opt.Remote.IsDir {
		flags += "d"
	}
	scp.cmd = strings.Join([]string{opt.Cmd, flags, opt.Remote.Name}, " ")
}
// send builds the upload command line and stores it in scp.cmd.
//
// opt.Cmd defaults to KRemoteCmd; this default is written back into opt.
// The flags are "-qt" followed by p, r and d when UpdatePermission, Recursive
// and Remote.IsDir are set, in that order. The result has the form
// "<cmd> <flags> <remote path>".
func (scp *Scp) send(opt *Option) {
	if len(opt.Cmd) == 0 {
		opt.Cmd = KRemoteCmd
	}
	flags := "-qt"
	if opt.UpdatePermission {
		flags += "p"
	}
	if opt.Recursive {
		flags += "r"
	}
	if opt.Remote.IsDir {
		flags += "d"
	}
	scp.cmd = strings.Join([]string{opt.Cmd, flags, opt.Remote.Name}, " ")
}
// remoteFile returns the file name to place in the upload header.
//
// When opt.Remote is a directory, it takes the base name of the local file.
// It appends that name to opt.Remote.Name with slash-separated path.Join,
// since remote paths are treated as slash-separated. It then clears
// opt.Remote.IsDir and returns the name. Otherwise it returns the base name of
// opt.Remote.Name and leaves opt unchanged.
func remoteFile(opt *Option) string {
	if !opt.Remote.IsDir {
		return path.Base(opt.Remote.Name)
	}
	// The local name is an OS path: filepath.Base also splits on '\' on
	// Windows, where path.Base would return the whole string.
	filename := filepath.Base(opt.Local.Name)
	opt.Remote.Name = path.Join(opt.Remote.Name, filename)
	opt.Remote.IsDir = false
	return filename
}
// localFile resolves the local file path for a transfer.
//
// When opt.Local is a directory, the base name of the remote path is appended
// to it and opt.Local.IsDir is cleared. The (possibly updated)
// opt.Local.Name is returned.
func localFile(opt *Option) string {
	if opt.Local.IsDir {
		// Remote paths are slash-separated (path.Base); the local directory is
		// an OS path, so it is joined with filepath.Join.
		filename := path.Base(opt.Remote.Name)
		opt.Local.Name = filepath.Join(opt.Local.Name, filename)
		opt.Local.IsDir = false
	}
	return opt.Local.Name
}
// checkResponse reads one reply from r. It returns nil for a non-failure
// reply, the parse error if reading fails, or an error holding the remote's
// message for a warning/error reply.
func checkResponse(r io.Reader) error {
	resp, err := ParseResponse(r)
	switch {
	case err != nil:
		return err
	case resp.IsFailure():
		return errors.New(resp.GetMessage())
	default:
		return nil
	}
}
// ResponseType is the leading status byte of a reply from the remote scp process.
type ResponseType = uint8

// Status bytes recognised by the predicates on Response. Any non-zero
// leading byte is followed by a message line; see ParseResponse.
const (
	Ok      ResponseType = 0
	Warning ResponseType = 1
	Error   ResponseType = 2
)

// Response is one parsed reply from the remote scp process.
type Response struct {
	Type    ResponseType // leading status byte
	Message string       // rest of the line including its '\n'; empty for Ok
}

// FileInfos holds the fields of a file header decoded by ParseFileInfos.
type FileInfos struct {
	Message     string // raw message the fields were parsed from
	Filename    string // may contain spaces
	Permissions string // mode field as sent, e.g. "0644"
	Size        int64  // number of content bytes that follow
}

// ParseResponse reads one reply from reader.
//
// The first byte is the response type. For Ok (0) nothing else is read. For
// any other byte, the rest of the line up to and including '\n' becomes the
// Message. A line cut off by EOF is an error.
//
// When reader is already a *bufio.Reader it is used directly, so bytes
// buffered past the line stay available to the caller. Otherwise a temporary
// bufio.Reader is created; it may read ahead past the '\n'.
func ParseResponse(reader io.Reader) (Response, error) {
	var head [1]byte
	// io.ReadFull retries short reads: a bare Read may legally return
	// (0, nil), which would otherwise be mistaken for an Ok byte.
	if _, err := io.ReadFull(reader, head[:]); err != nil {
		return Response{}, err
	}
	responseType := head[0]
	if responseType == Ok {
		return Response{Type: Ok}, nil
	}
	br, ok := reader.(*bufio.Reader)
	if !ok {
		br = bufio.NewReader(reader)
	}
	message, err := br.ReadString('\n')
	if err != nil {
		return Response{}, err
	}
	return Response{Type: responseType, Message: message}, nil
}

// IsOk returns true when the remote acknowledged with a zero byte.
func (r *Response) IsOk() bool {
	return r.Type == Ok
}

// IsWarning returns true when the remote responded with a warning.
func (r *Response) IsWarning() bool {
	return r.Type == Warning
}

// IsError returns true when the remote responded with an error.
func (r *Response) IsError() bool {
	return r.Type == Error
}

// IsFailure returns true when the remote answered with a warning or an error.
func (r *Response) IsFailure() bool {
	return r.IsWarning() || r.IsError()
}

// GetMessage returns the message the remote sent back.
func (r *Response) GetMessage() string {
	return r.Message
}

// ParseFileInfos decodes a "<perm> <size> <name>" header message.
//
// Newlines are removed first. The message is split into at most three
// fields, so a file name containing spaces is kept whole. The size is parsed
// as a 64-bit integer, so files over 2 GiB work where int is 32 bits. A
// message with fewer than three fields, or with a malformed or negative
// size, is an error.
func (r *Response) ParseFileInfos() (*FileInfos, error) {
	message := strings.ReplaceAll(r.Message, "\n", "")
	parts := strings.SplitN(message, " ", 3)
	if len(parts) < 3 {
		return nil, errors.New("unable to parse message as file infos")
	}
	size, err := strconv.ParseInt(parts[1], 10, 64)
	if err != nil {
		return nil, err
	}
	if size < 0 {
		return nil, errors.New("unable to parse message as file infos: negative size")
	}
	return &FileInfos{
		Message:     r.Message,
		Permissions: parts[0],
		Size:        size,
		Filename:    parts[2],
	}, nil
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Go
1
https://gitee.com/h79/goutils.git
git@gitee.com:h79/goutils.git
h79
goutils
goutils
v1.4.14

搜索帮助