代码拉取完成,页面将自动刷新
// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License
package xgboost
import (
"fmt"
"strconv"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
kubeflowv1 "gitee.com/vak80/training-operator/pkg/apis/kubeflow.org/v1"
)
// SetPodEnv appends the distributed-training environment variables to every
// container of podTemplate. The variables are set for:
//   - XGBoost Rabit Tracker and worker
//   - LightGBM master and workers
//
// job must be a *kubeflowv1.XGBoostJob, rtype is the replica type of the pod
// being built, and index is the pod's decimal index within that replica type.
//
// The following variables are always appended, in this order:
//
//	MASTER_PORT       port returned by getPortFromXGBoostJob for the master
//	MASTER_ADDR       replicaName of master replica 0
//	WORLD_SIZE        total replica count from computeTotalReplicas
//	RANK              index, plus the master replica count for worker pods
//	PYTHONUNBUFFERED  "1"
//
// When the job has more than one replica in total, WORKER_PORT and
// WORKER_ADDRS (a comma-separated list of worker replica names) are appended
// as well.
//
// An error is returned when job has the wrong type, when index is not an
// integer, or when a required port cannot be resolved.
func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
	xgboostjob, ok := job.(*kubeflowv1.XGBoostJob)
	if !ok {
		// Report the value that was actually passed in; xgboostjob is nil
		// after a failed assertion and would always print as <nil>.
		return fmt.Errorf("%+v is not a type of XGBoostJob", job)
	}

	rank, err := strconv.Atoi(index)
	if err != nil {
		return err
	}

	// Add master offset for worker pods. EqualFold is already
	// case-insensitive, so no explicit lower-casing is needed.
	if strings.EqualFold(rtype, string(kubeflowv1.XGBoostJobReplicaTypeWorker)) {
		// A missing master spec or an unset replica count contributes no
		// offset, matching how computeTotalReplicas skips unset counts.
		masterSpec := xgboostjob.Spec.XGBReplicaSpecs[kubeflowv1.XGBoostJobReplicaTypeMaster]
		if masterSpec != nil && masterSpec.Replicas != nil {
			rank += int(*masterSpec.Replicas)
		}
	}

	masterAddr := replicaName(xgboostjob.Name, kubeflowv1.XGBoostJobReplicaTypeMaster, 0)
	masterPort, err := getPortFromXGBoostJob(xgboostjob, kubeflowv1.XGBoostJobReplicaTypeMaster)
	if err != nil {
		return err
	}

	totalReplicas := computeTotalReplicas(xgboostjob)

	// Build the variable list once; it is identical for every container.
	envs := []corev1.EnvVar{
		{Name: "MASTER_PORT", Value: strconv.Itoa(int(masterPort))},
		{Name: "MASTER_ADDR", Value: masterAddr},
		{Name: "WORLD_SIZE", Value: strconv.Itoa(int(totalReplicas))},
		{Name: "RANK", Value: strconv.Itoa(rank)},
		{Name: "PYTHONUNBUFFERED", Value: "1"},
	}

	// These variables are used if it is a LightGBM job.
	if totalReplicas > 1 {
		workerPort, err := getPortFromXGBoostJob(xgboostjob, kubeflowv1.XGBoostJobReplicaTypeWorker)
		if err != nil {
			return err
		}

		// NOTE(review): the worker count is derived as totalReplicas-1, which
		// presumes exactly one master replica; confirm against job validation.
		workerAddrs := make([]string, totalReplicas-1)
		for i := range workerAddrs {
			workerAddrs[i] = replicaName(xgboostjob.Name, kubeflowv1.XGBoostJobReplicaTypeWorker, i)
		}

		envs = append(envs,
			corev1.EnvVar{Name: "WORKER_PORT", Value: strconv.Itoa(int(workerPort))},
			corev1.EnvVar{Name: "WORKER_ADDRS", Value: strings.Join(workerAddrs, ",")},
		)
	}

	for i := range podTemplate.Spec.Containers {
		container := &podTemplate.Spec.Containers[i]
		// append handles a nil Env slice, so no pre-allocation is required.
		container.Env = append(container.Env, envs...)
	}

	return nil
}
// replicaName builds the name of a single replica as
// "<jobName>-<lower-cased rtype>-<index>", with every "/" in the result
// replaced by "-".
func replicaName(jobName string, rtype kubeflowv1.ReplicaType, index int) string {
	parts := []string{
		jobName,
		strings.ToLower(string(rtype)),
		strconv.Itoa(index),
	}
	return strings.ReplaceAll(strings.Join(parts, "-"), "/", "-")
}
// getPortFromXGBoostJob gets the port of the xgboost container for the given
// replica type.
//
// It looks through the containers of that replica type's pod template for the
// one named kubeflowv1.XGBoostJobDefaultContainerName and returns the
// ContainerPort of its port named kubeflowv1.XGBoostJobDefaultPortName.
//
// It returns -1 and an error when the job has no spec for rtype, or when no
// matching container port is found.
func getPortFromXGBoostJob(job *kubeflowv1.XGBoostJob, rtype kubeflowv1.ReplicaType) (int32, error) {
	// Indexing the map with an absent key yields a nil spec; guard against
	// dereferencing it instead of panicking.
	spec := job.Spec.XGBReplicaSpecs[rtype]
	if spec == nil {
		return -1, fmt.Errorf("no replica spec found for replica type %q", string(rtype))
	}

	for _, container := range spec.Template.Spec.Containers {
		if container.Name != kubeflowv1.XGBoostJobDefaultContainerName {
			continue
		}
		for _, port := range container.Ports {
			if port.Name == kubeflowv1.XGBoostJobDefaultPortName {
				return port.ContainerPort, nil
			}
		}
	}
	return -1, fmt.Errorf("failed to found the port")
}
// computeTotalReplicas returns the sum of the replica counts over all replica
// specs of an XGBoostJob.
//
// obj is expected to be a *kubeflowv1.XGBoostJob. A value of any other type,
// or a nil job, yields 0 rather than a panic. Replica specs that are nil or
// have no replica count set are skipped.
func computeTotalReplicas(obj metav1.Object) int32 {
	job, ok := obj.(*kubeflowv1.XGBoostJob)
	if !ok || job == nil {
		return 0
	}

	var total int32
	// Ranging over a nil or empty map runs zero iterations, so no separate
	// emptiness check is needed.
	for _, spec := range job.Spec.XGBReplicaSpecs {
		if spec == nil || spec.Replicas == nil {
			continue
		}
		total += *spec.Replicas
	}
	return total
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。