# File: patchnotes/gwen.py  (72 lines, 2.0 KiB, Python)
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
from sys import argv, stderr as err
import threading
import torch
# Checkpoint to load: FP8-quantized Qwen3 8B.
model_name = "Qwen/Qwen3-8B-FP8"

# 1) Pick the compute device — prefer CUDA when a GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device, file=err)

# 2) Load tokenizer and model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# On GPU, request half precision to reduce VRAM use; on CPU keep the default dtype.
load_dtype = torch.float16 if device.type == "cuda" else None
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=load_dtype,
    device_map=device,
)
# Report the tokenizer/model context limits for reference.
print("max_length =", tokenizer.model_max_length, file=err)
print("max_embeds =", model.config.max_position_embeddings, file=err)
# 3) Prepare chat inputs (tokenized tensors).
# The prompt comes from the command line when arguments are given,
# otherwise from a file named "prompt" in the current directory.
if len(argv) > 1:
    # NOTE(review): args are concatenated with NO separator — confirm callers
    # quote the whole prompt as a single argument (otherwise words run together).
    prompt = "".join(argv[1:])
else:
    # Use a context manager so the handle is closed promptly (the original
    # leaked it), and read as UTF-8 explicitly instead of the platform default.
    with open("prompt", encoding="utf-8") as f:
        prompt = f.read().strip()

messages = [{"role": "user", "content": prompt}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant-turn header so the model generates a reply
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
num_input_tokens = inputs["input_ids"].shape[1]
tokens = num_input_tokens  # running total: prompt tokens + generated tokens (updated in the stream loop)
print("input tokens =", num_input_tokens, file=err)

# Move input tensors to the same device as the model.
inputs = {k: v.to(device) for k, v in inputs.items()}
# 4) Create a streamer that yields decoded text pieces as they are generated.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Generation budget. Kept in ONE place and reused in the progress line below,
# so the displayed limit can never drift from the actual limit (the original
# hard-coded 131072 in both spots).
max_new_tokens = 131072

# 5) Run generate() in a background thread (it blocks until finished),
# so the main thread is free to consume the stream in real time.
gen_kwargs = dict(
    **inputs,
    max_new_tokens=max_new_tokens,
    streamer=streamer,
)
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

# 6) Consume and display streamed text as it arrives.
for chunk in streamer:
    # NOTE(review): re-encoding each chunk can miscount slightly when one
    # token's text straddles chunk boundaries — acceptable for a progress line.
    tokens += len(tokenizer.encode(chunk, add_special_tokens=False))
    print(chunk, end="", flush=True)
    # \r keeps the counter on one stderr line, overwritten in place.
    print(f"{tokens}/{max_new_tokens} of token limit", end="\r", file=err)
print()
thread.join()
print()  # final newline (NOTE(review): together with the print above this emits two trailing newlines — confirm intended)