Skip to content

Add auto-gptq integration #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make torch installation optional
PanQiWei committed Apr 27, 2023
commit d2b413c2f6107cf4409a566e9aa91d87919ff093
39 changes: 36 additions & 3 deletions setup_env.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,10 @@

"""WARNING: this scripts may only works on linux"""

# change version based on your situation
pip_torch = "torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116",

pip_dependencies = [
# change version based on your situation
"torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116",
"transformers==4.25.1",
"sentencepiece",
"datasets",
@@ -25,6 +25,7 @@ def setup_env():
parser.add_argument("--init_conda", action="store_true")
parser.add_argument("--conda_name", type=str, default="moss")
parser.add_argument("--python_version", type=str, default="3.10")
parser.add_argument("--reinstall_torch", action="store_true")
args = parser.parse_args()

if args.init_conda:
@@ -36,10 +37,42 @@ def setup_env():
cwd=args.conda_home
).stdout.decode()
)

try:
import torch
except ImportError:
print(
subprocess.run(
f"./conda run -n {args.conda_name} pip install {pip_torch}".split(),
check=True,
stdout=subprocess.PIPE,
cwd=args.conda_home
).stdout.decode()
)
args.reinstall_torch = False

if args.reinstall_torch:
print(
subprocess.run(
f"./conda run -n {args.conda_name} pip uninstall torch -y".split(),
check=True,
stdout=subprocess.PIPE,
cwd=args.conda_home
).stdout.decode()
)
print(
subprocess.run(
f"./conda run -n {args.conda_name} pip install {pip_torch}".split(),
check=True,
stdout=subprocess.PIPE,
cwd=args.conda_home
).stdout.decode()
)

for pip_dependency in pip_dependencies:
print(
subprocess.run(
f"./conda run -n {args.conda_name} pip install -U {pip_dependency}".split(),
f"./conda run -n {args.conda_name} pip install {pip_dependency}".split(),
check=True,
stdout=subprocess.PIPE,
cwd=args.conda_home