Compare commits

...

1 Commits

Author SHA1 Message Date
dkeven
2447fdd74e fix(cli): unify node GPU info update logic 2025-12-22 16:31:03 +08:00
3 changed files with 90 additions and 70 deletions

View File

@@ -187,7 +187,7 @@ func (m *InstallPluginModule) Init() {
Prepare: &prepare.PrepareCollection{ Prepare: &prepare.PrepareCollection{
new(common.OnlyFirstMaster), new(common.OnlyFirstMaster),
}, },
Action: new(UpdateNodeLabels), Action: new(UpdateNodeGPUInfo),
Parallel: false, Parallel: false,
Retry: 1, Retry: 1,
} }
@@ -223,23 +223,6 @@ func (m *InstallPluginModule) Init() {
} }
} }
type GetCudaVersionModule struct {
common.KubeModule
}
func (g *GetCudaVersionModule) Init() {
g.Name = "GetCudaVersion"
getCudaVersion := &task.LocalTask{
Name: "GetCudaVersion",
Action: new(GetCudaVersion),
}
g.Tasks = []task.Interface{
getCudaVersion,
}
}
type NodeLabelingModule struct { type NodeLabelingModule struct {
common.KubeModule common.KubeModule
} }
@@ -253,7 +236,7 @@ func (l *NodeLabelingModule) Init() {
new(CudaInstalled), new(CudaInstalled),
new(CurrentNodeInK8s), new(CurrentNodeInK8s),
}, },
Action: new(UpdateNodeLabels), Action: new(UpdateNodeGPUInfo),
Retry: 1, Retry: 1,
} }

View File

@@ -10,7 +10,10 @@ import (
"strings" "strings"
"time" "time"
v1alpha1 "bytetrade.io/web3os/app-service/api/sys.bytetrade.io/v1alpha1"
apputils "bytetrade.io/web3os/app-service/pkg/utils"
ctrl "sigs.k8s.io/controller-runtime" ctrl "sigs.k8s.io/controller-runtime"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
"github.com/beclab/Olares/cli/pkg/clientset" "github.com/beclab/Olares/cli/pkg/clientset"
"github.com/beclab/Olares/cli/pkg/common" "github.com/beclab/Olares/cli/pkg/common"
@@ -26,7 +29,11 @@ import (
"github.com/pelletier/go-toml" "github.com/pelletier/go-toml"
"github.com/pkg/errors" "github.com/pkg/errors"
apixclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
kruntime "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
"k8s.io/client-go/util/retry" "k8s.io/client-go/util/retry"
) )
@@ -323,59 +330,11 @@ func (t *CheckGpuStatus) Execute(runtime connector.Runtime) error {
return fmt.Errorf("GPU Container State is Pending") return fmt.Errorf("GPU Container State is Pending")
} }
type GetCudaVersion struct { type UpdateNodeGPUInfo struct {
common.KubeAction common.KubeAction
} }
func (g *GetCudaVersion) Execute(runtime connector.Runtime) error { func (u *UpdateNodeGPUInfo) Execute(runtime connector.Runtime) error {
var nvidiaSmiFile string
var systemInfo = runtime.GetSystemInfo()
switch {
case systemInfo.IsWsl():
nvidiaSmiFile = "/usr/lib/wsl/lib/nvidia-smi"
default:
nvidiaSmiFile = "/usr/bin/nvidia-smi"
}
if !util.IsExist(nvidiaSmiFile) {
logger.Info("nvidia-smi not exists")
return nil
}
var cudaVersion string
res, err := runtime.GetRunner().Cmd(fmt.Sprintf("%s --version", nvidiaSmiFile), false, true)
if err != nil {
logger.Errorf("get cuda version error %v", err)
return nil
}
lines := strings.Split(res, "\n")
if len(lines) == 0 {
return nil
}
for _, line := range lines {
if strings.Contains(line, "CUDA Version") {
parts := strings.Split(line, ":")
if len(parts) != 2 {
break
}
cudaVersion = strings.TrimSpace(parts[1])
}
}
if cudaVersion != "" {
common.SetSystemEnv("OLARES_SYSTEM_CUDA_VERSION", cudaVersion)
}
return nil
}
type UpdateNodeLabels struct {
common.KubeAction
}
func (u *UpdateNodeLabels) Execute(runtime connector.Runtime) error {
client, err := clientset.NewKubeClient() client, err := clientset.NewKubeClient()
if err != nil { if err != nil {
return errors.Wrap(errors.WithStack(err), "kubeclient create error") return errors.Wrap(errors.WithStack(err), "kubeclient create error")
@@ -482,6 +441,85 @@ func UpdateNodeGpuLabel(ctx context.Context, client kubernetes.Interface, driver
} }
} }
if cuda != nil && *cuda != "" {
if err := updateCudaVersionSystemEnv(ctx, *cuda); err != nil {
logger.Errorf("failed to update SystemEnv for CUDA version: %v", err)
return err
}
}
return nil
}
func updateCudaVersionSystemEnv(ctx context.Context, cudaVersion string) error {
envName := "OLARES_SYSTEM_CUDA_VERSION"
common.SetSystemEnv(envName, cudaVersion)
config, err := ctrl.GetConfig()
if err != nil {
return fmt.Errorf("failed to get rest config: %w", err)
}
apix, err := apixclientset.NewForConfig(config)
if err != nil {
return fmt.Errorf("failed to create crd client: %w", err)
}
_, err = apix.ApiextensionsV1().CustomResourceDefinitions().Get(ctx, "systemenvs.sys.bytetrade.io", metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
logger.Debugf("SystemEnv CRD not found, skipping CUDA version update")
return nil
}
return fmt.Errorf("failed to get SystemEnv CRD: %w", err)
}
scheme := kruntime.NewScheme()
if err := v1alpha1.AddToScheme(scheme); err != nil {
return fmt.Errorf("failed to add systemenv scheme: %w", err)
}
c, err := ctrlclient.New(config, ctrlclient.Options{Scheme: scheme})
if err != nil {
return fmt.Errorf("failed to create client: %w", err)
}
resourceName, err := apputils.EnvNameToResourceName(envName)
if err != nil {
return fmt.Errorf("invalid system env name: %s", envName)
}
var existingSystemEnv v1alpha1.SystemEnv
err = c.Get(ctx, types.NamespacedName{Name: resourceName}, &existingSystemEnv)
if err == nil {
if existingSystemEnv.Default != cudaVersion {
existingSystemEnv.Default = cudaVersion
if err := c.Update(ctx, &existingSystemEnv); err != nil {
return fmt.Errorf("failed to update SystemEnv %s: %w", resourceName, err)
}
logger.Infof("Updated SystemEnv %s default to %s", resourceName, cudaVersion)
}
return nil
}
if !apierrors.IsNotFound(err) {
return fmt.Errorf("failed to get SystemEnv %s: %w", resourceName, err)
}
systemEnv := &v1alpha1.SystemEnv{
ObjectMeta: metav1.ObjectMeta{
Name: resourceName,
},
EnvVarSpec: v1alpha1.EnvVarSpec{
EnvName: envName,
Default: cudaVersion,
},
}
if err := c.Create(ctx, systemEnv); err != nil && !apierrors.IsAlreadyExists(err) {
return fmt.Errorf("failed to create SystemEnv %s: %w", resourceName, err)
}
logger.Infof("Created SystemEnv: %s with default %s", envName, cudaVersion)
return nil return nil
} }

View File

@@ -58,7 +58,6 @@ func (l *linuxInstallPhaseBuilder) installGpuPlugin() phase {
return []module.Module{ return []module.Module{
&gpu.RestartK3sServiceModule{Skip: !(l.runtime.Arg.Kubetype == common.K3s)}, &gpu.RestartK3sServiceModule{Skip: !(l.runtime.Arg.Kubetype == common.K3s)},
&gpu.InstallPluginModule{Skip: skipGpuPlugin}, &gpu.InstallPluginModule{Skip: skipGpuPlugin},
&gpu.GetCudaVersionModule{},
} }
} }