diff --git a/runtime/main.go b/runtime/main.go index b8775d16326302e6fdc5ca32b75e2b6d805bd9c9..0b66449eec4c7465483d9776b0961e12d62bad3b 100644 --- a/runtime/main.go +++ b/runtime/main.go @@ -17,10 +17,13 @@ package main import ( "context" + "crypto/tls" "encoding/json" + "errors" "fmt" "io/ioutil" "log" + "net/http" "os" "os/exec" "path" @@ -34,6 +37,7 @@ import ( "github.com/containerd/containerd/oci" "github.com/opencontainers/runtime-spec/specs-go" "huawei.com/npu-exporter/v5/common-utils/hwlog" + "k8s.io/api/core/v1" "main/dcmi" "mindxcheckutils" @@ -65,6 +69,8 @@ var ( hookDefaultFile = hookDefaultFilePath dockerRuncName = dockerRuncFile runcName = runcFile + + notMatchError = errors.New("container not match pod or pod not has huawei.com/Ascend910 annotation") ) const ( @@ -292,6 +298,26 @@ func removeDuplication(devices []int) []int { return list } +func parseAnnotationDevices(annotationDevices string) ([]int, error) { + devices := make([]int, 0) + + for _, d := range strings.Split(annotationDevices, ",") { + borders := strings.Split(d, Ascend910+"-") + if len(borders) != borderNum || borders[0] != "" { + return nil, fmt.Errorf("invalid device range: %s", d) + } + deviceID, err := strconv.Atoi(borders[1]) + if err != nil { + return nil, fmt.Errorf("invalid device ID: %s", d) + } + + devices = append(devices, deviceID) + } + + sort.Slice(devices, func(i, j int) bool { return i < j }) + return removeDuplication(devices), nil +} + func parseDevices(visibleDevices string) ([]int, error) { devices := make([]int, 0) const maxDevice = 128 @@ -472,6 +498,44 @@ func addManagerDevice(spec *specs.Spec) error { } func addDevice(spec *specs.Spec) error { + // 获取对应pod annotation中的设备信息 + annotationDevices, err := getDeviceFromPod(spec) + if err != nil && err != notMatchError { + // 报错不可直接返回,记录日志即可 + hwlog.RunLog.Errorf("getDeviceFromPod failed: %#v", err) + } + + // 如果没有匹配到pod或annotation,则通过环境变量挂载设备 + if annotationDevices == "" { + hwlog.RunLog.Info("add devices from env variable") + if err = addDeviceFromEnv(spec); err != nil { + return fmt.Errorf("failed to add device to env: %#v", err) + } + return nil + } + + // 如果对应pod annotation中的设备信息存在,则用这个信息挂载设备 + devices, err := parseAnnotationDevices(annotationDevices) + if err != nil { + return fmt.Errorf("failed to parse device: %#v", err) + } + hwlog.RunLog.Infof("annotation devices is: %#v", devices) + deviceName := davinciName + for _, deviceId := range devices { + dPath := devicePath + deviceName + strconv.Itoa(deviceId) + if err = addDeviceToSpec(spec, dPath, deviceName); err != nil { + return fmt.Errorf("failed to add davinci device to spec: %#v", err) + } + } + + if err = addManagerDevice(spec); err != nil { + return fmt.Errorf("failed to add Manager device to spec: %#v", err) + } + + return nil +} + +func addDeviceFromEnv(spec *specs.Spec) error { visibleDevices := getValueByKey(spec.Process.Env, ascendVisibleDevices) if visibleDevices == "" { return nil @@ -650,3 +714,61 @@ func main() { log.Fatal(err) } } + +func getDeviceFromPod(spec *specs.Spec) (string, error) { + // 只有apiserver会访问kubelet的https api接口,所以使用apiserver的客户端证书;证书需要从master节点拷贝到worker节点 + certFile := "/etc/kubernetes/pki/apiserver-kubelet-client.crt" + keyFile := "/etc/kubernetes/pki/apiserver-kubelet-client.key" + kubeletUrl := "https://127.0.0.1:10250/" + podsUrlPath := "pods" + npu910CardName := "huawei.com/Ascend910" + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + hwlog.RunLog.Errorf("LoadX509KeyPair failed: %#v", err) + return "", err + } + + // 构造带客户端证书的http客户端 + client := &http.Client{ + Transport: &http.Transport{ + Proxy: nil, // 禁用代理 + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, // kubelet 是自签名ca证书,apiserver也未校验kubelet服务端证书,所以这里不校验 + }, + }, + } + + // 向kubelet服务端请求获取pod list + resp, err := client.Get(kubeletUrl + podsUrlPath) + if err != nil { + hwlog.RunLog.Errorf("http get failed: %#v", err) + return "", err + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + hwlog.RunLog.Errorf("ReadAll resp.Body failed: %#v", err) + return "", err + } + hwlog.RunLog.Infof("get pod list success, resp.Status: %#v", resp.Status) + + // 遍历pod list,找到此容器 + var podList v1.PodList + if err := json.Unmarshal(body, &podList); err != nil { + hwlog.RunLog.Errorf("unmarshal body failed: %#v", err) + return "", err + } + + for _, pod := range podList.Items { + if pod.ObjectMeta.Name == spec.Hostname { + if value, ok := pod.ObjectMeta.Annotations[npu910CardName]; ok { + return value, nil + } + break + } + } + return "", notMatchError +}7 \ No newline at end of file