initial commit
gwen.py (Normal file, 72 lines added)
@@ -0,0 +1,72 @@
from sys import stderr as err
from sys import argv

## Streaming generation on GPU (CUDA) with TextIteratorStreamer
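#
# Usage (inferred from the argv / prompt-file handling below):
#   python gwen.py "Explain KV caching"   # prompt from the command line
#   python gwen.py                        # prompt read from a local file named "prompt"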

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
import threading
import torch

model_name = "Qwen/Qwen3-8B-FP8"
# model_name = "Qwen/Qwen3-8B"
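# Note (assumption): the -FP8 suffix denotes a pre-quantized variant of Qwen3-8B
# that needs a GPU/runtime with FP8 support; swap in the commented line for the
# unquantized model.
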
# 1) Choose device (use CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device, file=err)

# 2) Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# If the GPU has limited VRAM, consider dtype=torch.float16 for half precision
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16 if device.type == "cuda" else None,
    device_map=device,
)
# device_map=device already places the weights on the right device,
# so a separate model.to(device) is redundant

print("tokenizer.model_max_length =", tokenizer.model_max_length, file=err)
print("model.config.max_position_embeddings =", model.config.max_position_embeddings, file=err)

# 3) Prepare chat inputs (tokenized tensors)
if len(argv) > 1:
    # join CLI args with spaces so a multi-word prompt works without quoting
    prompt = " ".join(argv[1:])
else:
    prompt = open("prompt").read().strip()
messages = [{"role": "user", "content": prompt}]
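# messages is a single-turn chat; extra {"role", "content"} dicts would make it multi-turn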
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
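# Note (assumption): Qwen3's chat template also accepts an enable_thinking flag
# through apply_chat_template; it is left at its default here.
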
# Move input tensors to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}

# 4) Create streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
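# skip_prompt=True keeps the echoed prompt out of the stream; skip_special_tokens
# is forwarded to the decoder so special markers are dropped from the output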
# 5) Start generation in background thread (generate is blocking)
gen_kwargs = dict(
    **inputs,
    max_new_tokens=32768,
    streamer=streamer,
)
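# max_new_tokens is an upper bound; generation stops earlier at end-of-sequence,
# and prompt + new tokens must still fit the context limits printed above.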
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
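# If generate raises inside the worker thread, the loop below could block forever;
# TextIteratorStreamer's timeout parameter (unused here) would turn that into an
# exception instead.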
# 6) Consume and display streamed text in real time
generated_text = ""
for chunk in streamer:
    generated_text += chunk
    print(chunk, end="", flush=True)
thread.join()
print() # final newline
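# generated_text holds the complete reply if post-processing is ever needed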