from sys import stderr as err
from sys import argv

## Streaming generation on GPU (CUDA) with TextIteratorStreamer

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
import threading
import torch

model_name = "Qwen/Qwen3-8B-FP8"
# model_name = "Qwen/Qwen3-8B"
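# The "-FP8" suffix selects the quantized FP8 release of Qwen3-8B; the
# commented-out name above loads the full-precision checkpoint instead.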

# 1) Choose device (use CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device, file=err)

# 2) Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# If GPU and limited VRAM, consider dtype=torch.float16 for half precision
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16 if device.type == "cuda" else None,  # `torch_dtype` on older transformers releases
    device_map=device,
)
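# device_map hands weight placement to accelerate; given a single device, the
# whole model is loaded onto it, so no separate .to(device) call is needed.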

print("tokenizer.model_max_length =", tokenizer.model_max_length, file=err)
print("model.config.max_position_embeddings =", model.config.max_position_embeddings, file=err)

# 3) Prepare chat inputs (tokenized tensors)
if len(argv) > 1:
    prompt = " ".join(argv[1:])
else:
    with open("prompt") as f:
        prompt = f.read().strip()
messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
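# With return_dict=True and return_tensors="pt", apply_chat_template wraps the
# prompt in the model's chat format and returns input_ids / attention_mask
# tensors ready to pass to generate().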

# Move input tensors to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}

# 4) Create streamer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
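# TextIteratorStreamer exposes the generated text as a Python iterator;
# skip_prompt drops the echoed input and skip_special_tokens strips
# end-of-sequence and other special tokens from the output.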

# 5) Start generation in background thread (generate is blocking)
gen_kwargs = dict(
    **inputs,
    max_new_tokens=32768,
    streamer=streamer,
)
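# Sampling settings not given here (temperature, top_p, do_sample, ...) fall back
# to the model's generation_config; add them to gen_kwargs to override.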

thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
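# model.generate() blocks until generation finishes, so it runs in a background
# thread while the main thread consumes the streamer below.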

# 6) Consume and display streamed text in real time
generated_text = ""
for chunk in streamer:
    generated_text += chunk
    print(chunk, end="", flush=True)
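# After the loop, generated_text holds the complete response as a single string.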

thread.join()
print()  # final newline
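# Example invocation (the script name here is just a placeholder):
#   python stream_generate.py "Explain the KV cache in two sentences."
#   python stream_generate.py           # no args: reads the prompt from ./prompt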