Skip to content

Commit c213873

Browse files
committed
Add NVIDIA hardware info display to settings with PyTorch version switching capability. Auto-select optimal PyTorch version during initial setup based on detected hardware.
1 parent 7138bf3 commit c213873

File tree

11 files changed

+336
-10
lines changed

11 files changed

+336
-10
lines changed

backend-golang/hw_info.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package backend_golang
2+
3+
import (
4+
"errors"
5+
"os/exec"
6+
"strconv"
7+
"strings"
8+
)
9+
10+
func (a *App) GetNvidiaGpuCount() (int, error) {
11+
// temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used
12+
// gpu_name,gpu_bus_id,driver_version
13+
// nvidia-smi --help-query-gpu
14+
output, err := exec.Command("nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits").CombinedOutput()
15+
if err != nil {
16+
return 0, err
17+
}
18+
return strconv.Atoi(strings.TrimSpace(string(output)))
19+
}
20+
21+
func (a *App) GetCudaComputeCapability(index int) (string, error) {
22+
output, err := exec.Command("nvidia-smi", "-i="+strconv.Itoa(index), "--query-gpu=compute_cap", "--format=csv,noheader,nounits").CombinedOutput()
23+
if err != nil {
24+
return "", err
25+
}
26+
27+
computeCap := strings.TrimSpace(string(output))
28+
if computeCap == "" {
29+
return "", errors.New("compute capability is empty")
30+
}
31+
32+
return computeCap, nil
33+
}
34+
35+
func (a *App) GetMaxCudaComputeCapability() (string, error) {
36+
gpuCount, err := a.GetNvidiaGpuCount()
37+
if err != nil {
38+
return "", err
39+
}
40+
maxComputeCap := "0.0"
41+
for i := 0; i < gpuCount; i++ {
42+
computeCap, err := a.GetCudaComputeCapability(i)
43+
if err != nil {
44+
return "", err
45+
}
46+
computeCapFloat, err := strconv.ParseFloat(computeCap, 64)
47+
if err != nil {
48+
return "", err
49+
}
50+
maxComputeCapFloat, err := strconv.ParseFloat(maxComputeCap, 64)
51+
if err != nil {
52+
return "", err
53+
}
54+
if computeCapFloat > maxComputeCapFloat {
55+
maxComputeCap = computeCap
56+
}
57+
}
58+
if maxComputeCap == "0.0" {
59+
return "", errors.New("no cuda compute capability")
60+
}
61+
return maxComputeCap, nil
62+
}
63+
64+
func (a *App) GetSupportedCudaVersion() (string, error) {
65+
output, err := exec.Command("nvidia-smi", "--query").CombinedOutput()
66+
if err != nil {
67+
return "", err
68+
}
69+
70+
lines := strings.Split(string(output), "\n")
71+
72+
for _, line := range lines {
73+
if strings.Contains(line, "CUDA Version") {
74+
return strings.TrimSpace(strings.Split(line, ":")[1]), nil
75+
}
76+
}
77+
78+
return "", errors.New("cuda version is empty")
79+
}
80+
81+
func (a *App) GetTorchVersion(python string) (string, error) {
82+
var err error
83+
if python == "" {
84+
python, err = a.GetPython()
85+
if err != nil {
86+
return "", err
87+
}
88+
}
89+
90+
output, err := exec.Command(python, "-c", "import torch; print(torch.__version__)").CombinedOutput()
91+
if err != nil {
92+
return "", err
93+
}
94+
95+
version := strings.TrimSpace(string(output))
96+
if version == "" {
97+
return "", errors.New("torch version is empty")
98+
}
99+
100+
return version, nil
101+
}

backend-golang/rwkv.go

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package backend_golang
44
import (
55
"encoding/json"
66
"errors"
7+
"fmt"
78
"os"
89
"os/exec"
910
"runtime"
@@ -221,13 +222,49 @@ func (a *App) MergeLora(python string, useGpu bool, loraAlpha int, baseModel str
221222
return Cmd(args...)
222223
}
223224

224-
func (a *App) InstallPyDep(python string, cnMirror bool) (string, error) {
225+
func (a *App) InstallTorch(python string, cnMirror bool, torchVersion string, cuSourceVersion string) (string, error) {
226+
if runtime.GOOS != "windows" {
227+
return "", errors.New("only support windows")
228+
}
229+
230+
var err error
231+
cuSourceVersion = strings.Replace(cuSourceVersion, ".", "", 1)
232+
torchWhlUrl := fmt.Sprintf("torch==%s --index-url https://download.pytorch.org/whl/cu%s", torchVersion, cuSourceVersion)
233+
if python == "" {
234+
python, err = a.GetPython()
235+
if cnMirror && python == "py310/python.exe" {
236+
torchWhlUrl = fmt.Sprintf("https://mirrors.aliyun.com/pytorch-wheels/cu%s/torch-%s+cu%s-cp310-cp310-win_amd64.whl", cuSourceVersion, torchVersion, cuSourceVersion)
237+
}
238+
if runtime.GOOS == "windows" {
239+
python = `"%CD%/` + python + `"`
240+
}
241+
}
242+
if err != nil {
243+
return "", err
244+
}
245+
246+
a.ChangeFileLine("./py310/python310._pth", 3, "Lib\\site-packages")
247+
installScript := python + " ./backend-python/get-pip.py -i https://mirrors.aliyun.com/pypi/simple --no-warn-script-location\n" +
248+
python + " -m pip install " + torchWhlUrl + " --no-warn-script-location\n" +
249+
"exit"
250+
if !cnMirror {
251+
installScript = strings.Replace(installScript, " -i https://mirrors.aliyun.com/pypi/simple", "", -1)
252+
}
253+
err = os.WriteFile(a.exDir+"install-py-dep.bat", []byte(installScript), 0644)
254+
if err != nil {
255+
return "", err
256+
}
257+
return Cmd("install-py-dep.bat")
258+
}
259+
260+
func (a *App) InstallPyDep(python string, cnMirror bool, torchVersion string, cuSourceVersion string) (string, error) {
225261
var err error
226-
torchWhlUrl := "torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117"
262+
cuSourceVersion = strings.Replace(cuSourceVersion, ".", "", 1)
263+
torchWhlUrl := fmt.Sprintf("torch==%s --index-url https://download.pytorch.org/whl/cu%s", torchVersion, cuSourceVersion)
227264
if python == "" {
228265
python, err = a.GetPython()
229266
if cnMirror && python == "py310/python.exe" {
230-
torchWhlUrl = "https://mirrors.aliyun.com/pytorch-wheels/cu117/torch-1.13.1+cu117-cp310-cp310-win_amd64.whl"
267+
torchWhlUrl = fmt.Sprintf("https://mirrors.aliyun.com/pytorch-wheels/cu%s/torch-%s+cu%s-cp310-cp310-win_amd64.whl", cuSourceVersion, torchVersion, cuSourceVersion)
231268
}
232269
if runtime.GOOS == "windows" {
233270
python = `"%CD%/` + python + `"`

frontend/src/_locales/ja/main.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,5 +379,9 @@
379379
"Download File": "ファイルをダウンロード",
380380
"Force Enable Deep Think, Currently Only DeepSeek API and RWKV Runner Server Support": "深い思考を強制的に有効化、現在はDeepSeek APIとRWKV Runnerサーバーのみサポート",
381381
"DeepThink": "深い思考",
382-
"Reasoning": "推論思考"
382+
"Reasoning": "推論思考",
383+
"Pytorch Version": "Pytorchバージョン",
384+
"Not Installed": "未インストール",
385+
"Driver CUDA Version": "ドライバーCUDAバージョン",
386+
"CUDA Compute Capability": "CUDA計算互換性"
383387
}

frontend/src/_locales/zh-hans/main.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,5 +379,9 @@
379379
"Download File": "下载文件",
380380
"Force Enable Deep Think, Currently Only DeepSeek API and RWKV Runner Server Support": "强制启用深度思考, 目前仅 DeepSeek API 及 RWKV Runner 服务器支持",
381381
"DeepThink": "深度思考",
382-
"Reasoning": "推理思考"
382+
"Reasoning": "推理思考",
383+
"Pytorch Version": "Pytorch版本",
384+
"Not Installed": "未安装",
385+
"Driver CUDA Version": "驱动CUDA版本",
386+
"CUDA Compute Capability": "CUDA计算兼容性"
383387
}

frontend/src/pages/Settings.tsx

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,38 @@ import {
66
AccordionPanel,
77
Dropdown,
88
Input,
9+
Label,
910
Option,
1011
Switch,
1112
} from '@fluentui/react-components'
13+
import { compare } from 'compare-versions'
1214
import { observer } from 'mobx-react-lite'
1315
import { useTranslation } from 'react-i18next'
1416
import { toast } from 'react-toastify'
15-
import { RestartApp } from '../../wailsjs/go/backend_golang/App'
17+
import {
18+
GetTorchVersion,
19+
InstallTorch,
20+
RestartApp,
21+
} from '../../wailsjs/go/backend_golang/App'
1622
import { Labeled } from '../components/Labeled'
1723
import { Page } from '../components/Page'
1824
import commonStore from '../stores/commonStore'
1925
import { Language, Languages } from '../types/settings'
2026
import { checkUpdate, toastWithButton } from '../utils'
27+
import {
28+
getAvailableTorchCuVersion,
29+
torchVersions,
30+
} from '../utils/get-available-torch-cu-version'
2131

2232
export const GeneralSettings: FC = observer(() => {
2333
const { t } = useTranslation()
2434

35+
useEffect(() => {
36+
if (commonStore.platform === 'windows' && !commonStore.torchVersion) {
37+
commonStore.refreshTorchVersion()
38+
}
39+
}, [])
40+
2541
return (
2642
<div className="flex flex-col gap-2">
2743
<Labeled
@@ -89,6 +105,63 @@ export const GeneralSettings: FC = observer(() => {
89105
}
90106
/>
91107
)}
108+
{commonStore.platform === 'windows' && (
109+
<Labeled
110+
label={t('Pytorch Version')}
111+
flex
112+
spaceBetween
113+
content={
114+
<Dropdown
115+
style={{ minWidth: 0 }}
116+
listbox={{ style: { minWidth: 'fit-content' } }}
117+
value={commonStore.torchVersion || t('Not Installed')!}
118+
selectedOptions={[
119+
commonStore.torchVersion
120+
? torchVersions.find((v) =>
121+
commonStore.torchVersion.includes(v)
122+
) || ''
123+
: '',
124+
]}
125+
onOptionSelect={(_, data) => {
126+
const selectedVersion = data.optionValue
127+
if (selectedVersion) {
128+
const isSelectingCurrent =
129+
commonStore.torchVersion.includes(selectedVersion)
130+
if (!isSelectingCurrent) {
131+
const { torchVersion, cuSourceVersion } =
132+
getAvailableTorchCuVersion(
133+
selectedVersion,
134+
commonStore.driverCudaVersion || '11.7'
135+
)
136+
InstallTorch(
137+
commonStore.settings.customPythonPath,
138+
commonStore.settings.cnMirror,
139+
torchVersion,
140+
cuSourceVersion
141+
)
142+
.then(() => {
143+
commonStore.refreshTorchVersion()
144+
})
145+
.catch((e) => {
146+
toast.error(e)
147+
})
148+
}
149+
}
150+
}}
151+
>
152+
{torchVersions.map((v) => (
153+
<Option key={v} value={v}>
154+
{v}
155+
</Option>
156+
))}
157+
</Dropdown>
158+
}
159+
/>
160+
)}
161+
{commonStore.platform === 'windows' &&
162+
commonStore.cudaComputeCapability && (
163+
<Label size="small">{`${t('Driver CUDA Version')}: ${commonStore.driverCudaVersion} - ${t('CUDA Compute Capability')}: ${commonStore.cudaComputeCapability}`}</Label>
164+
)}
92165
<Labeled
93166
label={t('Dark Mode')}
94167
flex

frontend/src/startup.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import { toast } from 'react-toastify'
55
import manifest from '../../manifest.json'
66
import {
77
FileExists,
8+
GetMaxCudaComputeCapability,
89
GetPlatform,
10+
GetSupportedCudaVersion,
11+
GetTorchVersion,
912
ListDirFiles,
1013
ReadJson,
1114
} from '../wailsjs/go/backend_golang/App'
@@ -75,6 +78,15 @@ export async function startup() {
7578
// })
7679
// }
7780
})
81+
setTimeout(() => {
82+
GetMaxCudaComputeCapability().then((c) => {
83+
commonStore.setCudaComputeCapability(c)
84+
})
85+
GetSupportedCudaVersion().then((v) => {
86+
commonStore.setDriverCudaVersion(v)
87+
})
88+
commonStore.refreshTorchVersion()
89+
}, 1000)
7890
}
7991
}
8092

frontend/src/stores/commonStore.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { ChartData } from 'chart.js'
22
import i18n from 'i18next'
33
import { makeAutoObservable } from 'mobx'
44
import manifest from '../../../manifest.json'
5+
import { GetTorchVersion } from '../../wailsjs/go/backend_golang/App'
56
import { WindowSetDarkTheme, WindowSetLightTheme } from '../../wailsjs/runtime'
67
import {
78
defaultCompositionPrompt,
@@ -72,6 +73,9 @@ class CommonStore {
7273
monitorData: MonitorData | null = null
7374
depComplete: boolean = false
7475
platform: Platform = 'windows'
76+
cudaComputeCapability: string = ''
77+
driverCudaVersion: string = ''
78+
torchVersion: string = ''
7579
proxyPort: number = 0
7680
lastModelName: string = ''
7781
stateModels: string[] = []
@@ -371,6 +375,20 @@ class CommonStore {
371375
this.platform = value
372376
}
373377

378+
setCudaComputeCapability(value: string) {
379+
this.cudaComputeCapability = value
380+
}
381+
382+
setDriverCudaVersion(value: string) {
383+
this.driverCudaVersion = value
384+
}
385+
386+
refreshTorchVersion() {
387+
GetTorchVersion(this.settings.customPythonPath).then((v) => {
388+
this.torchVersion = v
389+
})
390+
}
391+
374392
setProxyPort(value: number) {
375393
this.proxyPort = value
376394
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import { compare } from 'compare-versions'
2+
3+
export const torchVersions = ['1.13.1', '2.7.1']
4+
5+
export function getAvailableTorchCuVersion(
6+
torchVersion: string,
7+
driverCudaVersion: string
8+
) {
9+
let retTorchVersion = ''
10+
let retCuSourceVersion = ''
11+
const targetTorchVersion = torchVersion.split('+')[0]
12+
if (compare(targetTorchVersion, '2.7.1', '>=')) {
13+
retTorchVersion = '2.7.1'
14+
if (compare(driverCudaVersion, '12.8', '>=')) {
15+
retCuSourceVersion = '12.8'
16+
} else if (compare(driverCudaVersion, '12.6', '>=')) {
17+
retCuSourceVersion = '12.6'
18+
} else {
19+
retCuSourceVersion = '11.8'
20+
}
21+
} else {
22+
retTorchVersion = '1.13.1'
23+
if (compare(driverCudaVersion, '11.7', '>=')) {
24+
retCuSourceVersion = '11.7'
25+
} else {
26+
retCuSourceVersion = '11.6'
27+
}
28+
}
29+
return { torchVersion: retTorchVersion, cuSourceVersion: retCuSourceVersion }
30+
}

0 commit comments

Comments
 (0)