Files
patchnotes/gwen.py
2025-12-30 18:58:23 +01:00

73 lines
1.9 KiB
Python

from sys import stderr as err
from sys import argv
## Streaming generation on GPU (CUDA) with TextIteratorStreamer
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
import threading
import torch
# Checkpoint to load. Alternative checkpoint: "Qwen/Qwen3-8B".
model_name = "Qwen/Qwen3-8B-FP8"

# 1) Choose device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device, file=err)

# 2) Load tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# float16 is requested only on CUDA; on CPU dtype=None is passed through.
load_dtype = None
if device.type == "cuda":
    load_dtype = torch.float16

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=load_dtype,
    device_map=device,
)
model.to(device)

# Log the length limits reported by the tokenizer and the model config to stderr.
print("tokenizer.model_max_length =", tokenizer.model_max_length, file=err)
print(
    "model.config.max_position_embeddings =",
    model.config.max_position_embeddings,
    file=err,
)
# 3) Prepare chat inputs (tokenized tensors)
# The prompt is taken from the command-line arguments when any are given;
# otherwise it is read from a file named "prompt" in the working directory.
if len(argv) > 1:
    # The shell has already split the prompt on whitespace, so the words are
    # rejoined with single spaces rather than concatenated without separators.
    prompt = " ".join(argv[1:])
else:
    # Context manager so the file handle is closed as soon as it is read.
    with open("prompt") as prompt_file:
        prompt = prompt_file.read().strip()

# Single-turn conversation: the whole prompt is one user message.
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
# Move input tensors to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}
# 4) Create streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# 5) Start generation in background thread (generate is blocking)
gen_kwargs = dict(
    **inputs,
    max_new_tokens=32768,
    streamer=streamer,
)
# Exceptions raised by the worker thread, collected so that the main thread
# can re-raise them after the stream has been drained.
generation_errors = []


def _run_generation():
    """Run model.generate and unblock the consumer loop if it fails.

    Without this wrapper, an exception inside generate() would kill the
    worker thread before the streamer is told that generation is over, and
    the main thread's iteration over the streamer would wait forever.
    """
    try:
        model.generate(**gen_kwargs)
    except Exception as exc:
        generation_errors.append(exc)
        # Signal end-of-stream so that `for chunk in streamer` terminates.
        streamer.end()


thread = threading.Thread(target=_run_generation)
thread.start()

# 6) Consume and display streamed text in real time
chunks = []
for chunk in streamer:
    chunks.append(chunk)
    # Flush on every chunk so the text appears as soon as it is produced.
    print(chunk, end="", flush=True)
thread.join()
# Full generated text, assembled once instead of by repeated concatenation.
generated_text = "".join(chunks)
print()  # final newline

# Surface a generation failure in the main thread (non-zero exit status).
if generation_errors:
    raise generation_errors[0]