08/01/2026 patch
This commit is contained in:
14
gwen.py
14
gwen.py
@@ -1,8 +1,6 @@
|
||||
from sys import stderr as err
|
||||
from sys import argv
|
||||
|
||||
## Streaming generation on GPU (CUDA) with TextIteratorStreamer
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
@@ -12,7 +10,6 @@ import threading
|
||||
import torch
|
||||
|
||||
model_name = "Qwen/Qwen3-8B-FP8"
|
||||
# model_name = "Qwen/Qwen3-8B"
|
||||
|
||||
# 1) Choose device (use CUDA if available)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -26,7 +23,6 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
dtype=torch.float16 if device.type == "cuda" else None,
|
||||
device_map=device)
|
||||
model.to(device)
|
||||
|
||||
print("tokenizer.model_max_length =", tokenizer.model_max_length, file=err)
|
||||
print("model.config.max_position_embeddings =", model.config.max_position_embeddings, file=err)
|
||||
@@ -45,6 +41,10 @@ inputs = tokenizer.apply_chat_template(
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
num_input_tokens = inputs["input_ids"].shape[1]
|
||||
tokens = num_input_tokens
|
||||
print("input tokens =", num_input_tokens, file=err)
|
||||
|
||||
# Move input tensors to the same device as the model
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
@@ -62,11 +62,11 @@ thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
# 6) Consume and display streamed text in real time
|
||||
generated_text = ""
|
||||
for chunk in streamer:
|
||||
generated_text += chunk
|
||||
tokens += len(tokenizer.encode(chunk, add_special_tokens=False))
|
||||
print(chunk, end="", flush=True)
|
||||
print(tokens, "/131072 of token limit", end="\r", sep="", file=err)
|
||||
print()
|
||||
|
||||
thread.join()
|
||||
print() # final newline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user