diff --git a/pkg/trainer/replicas.go b/pkg/trainer/replicas.go index c3e76ac..0feef4e 100644 --- a/pkg/trainer/replicas.go +++ b/pkg/trainer/replicas.go @@ -211,26 +211,36 @@ func (s *MXReplicaSet) createDist() error { if spec.ContainerName(c.Name) != spec.MXNET { continue } - if len(c.Env) == 0 { - c.Env = make([]v1.EnvVar, 5) + if c.Env == nil { + c.Env = []v1.EnvVar{} } for _, r := range s.Job.job.Spec.ReplicaSpecs { switch r.MxReplicaType { case spec.SCHEDULER: - c.Env[0].Name = "DMLC_PS_ROOT_PORT" - c.Env[0].Value = strconv.Itoa(int(*r.PsRootPort)) - c.Env[1].Name = "DMLC_PS_ROOT_URI" - c.Env[1].Value = fmt.Sprintf("%v-%v-%v-%v", s.Job.job.Metadata.Name, strings.ToLower(string(r.MxReplicaType)), s.Job.job.Spec.RuntimeId, 0) + c.Env = append(c.Env, v1.EnvVar{ + Name: "DMLC_PS_ROOT_PORT", + Value: strconv.Itoa(int(*r.PsRootPort)), + }) + c.Env = append(c.Env, v1.EnvVar{ + Name: "DMLC_PS_ROOT_URI", + Value: fmt.Sprintf("%v-%v-%v-%v", s.Job.job.Metadata.Name, strings.ToLower(string(r.MxReplicaType)), s.Job.job.Spec.RuntimeId, 0), + }) case spec.SERVER: - c.Env[2].Name = "DMLC_NUM_SERVER" - c.Env[2].Value = strconv.Itoa(int(*r.Replicas)) + c.Env = append(c.Env, v1.EnvVar{ + Name: "DMLC_NUM_SERVER", + Value: strconv.Itoa(int(*r.Replicas)), + }) case spec.WORKER: - c.Env[3].Name = "DMLC_NUM_WORKER" - c.Env[3].Value = strconv.Itoa(int(*r.Replicas)) + c.Env = append(c.Env, v1.EnvVar{ + Name: "DMLC_NUM_WORKER", + Value: strconv.Itoa(int(*r.Replicas)), + }) } } - c.Env[4].Name = "DMLC_ROLE" - c.Env[4].Value = strings.ToLower(string(s.Spec.MxReplicaType)) + c.Env = append(c.Env, v1.EnvVar{ + Name: "DMLC_ROLE", + Value: strings.ToLower(string(s.Spec.MxReplicaType)), + }) } log.Infof("Creating Job: %v", newJ.ObjectMeta.Name)