reach-vb (HF Staff) committed
Commit adf5a8e · verified · 1 Parent(s): 1cfc710
Files changed (1):
  1. app.py +112 -43
app.py CHANGED
@@ -7,8 +7,8 @@ from transformers import pipeline
 import spaces
 
 # === Config (override via Space secrets/env vars) ===
-MODEL_ID = os.environ.get("MODEL_ID", "openai/gpt-oss-20b")
-STATIC_PROMPT = """"""
+MODEL_ID = os.environ.get("MODEL_ID", "tlhv/osb-minier")
+STATIC_PROMPT = ""  # optional extra system message
 DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 512))
 DEFAULT_TEMPERATURE = float(os.environ.get("TEMPERATURE", 0.7))
 DEFAULT_TOP_P = float(os.environ.get("TOP_P", 0.95))
@@ -16,30 +16,72 @@ DEFAULT_REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.0))
 ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", 120))  # seconds
 
 _pipe = None  # cached pipeline
-_tok = None  # tokenizer for parsing Harmony format
 
 
+# ----------------------------
+# Helpers (simple & explicit)
+# ----------------------------
+
 def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
-    messages = []
+    msgs: List[Dict[str, str]] = []
     if policy.strip():
-        messages.append({"role": "system", "content": policy.strip()})
-    # if STATIC_PROMPT:
-    #     messages.append({"role": "system", "content": STATIC_PROMPT})
-    messages.append({"role": "user", "content": user_prompt})
-    return messages
-
-
-def _parse_harmony_output(last, tokenizer):
-    analysis, content = None, None
-    if isinstance(last, dict) and ("content" in last or "thinking" in last):
-        analysis = last.get("thinking")
-        content = last.get("content")
-    else:
-        parsed = tokenizer.parse_response(last)
-        analysis = parsed.get("thinking")
-        content = parsed.get("content")
-    return analysis, content
-
+        msgs.append({"role": "system", "content": policy.strip()})
+    if STATIC_PROMPT:
+        msgs.append({"role": "system", "content": STATIC_PROMPT})
+    msgs.append({"role": "user", "content": user_prompt})
+    return msgs
+
+
+def _extract_assistant_content(outputs) -> str:
+    """Extract the assistant's content from the known shape:
+    outputs = [
+        {
+            'generated_text': [
+                {'role': 'system', 'content': ...},
+                {'role': 'user', 'content': ...},
+                {'role': 'assistant', 'content': 'analysis...assistantfinal...'}
+            ]
+        }
+    ]
+    Keep this forgiving and minimal.
+    """
+    try:
+        msgs = outputs[0]["generated_text"]
+        for m in reversed(msgs):
+            if isinstance(m, dict) and m.get("role") == "assistant":
+                return m.get("content", "")
+        last = msgs[-1]
+        return last.get("content", "") if isinstance(last, dict) else str(last)
+    except Exception:
+        return str(outputs)
+
+
+def _parse_harmony_output_from_string(s: str) -> Tuple[str, str]:
+    """Split a Harmony-style concatenated string into (analysis, final).
+    Expects markers 'analysis' ... 'assistantfinal'.
+    No heavy parsing — just string finds.
+    """
+    if not isinstance(s, str):
+        s = str(s)
+    final_key = "assistantfinal"
+    j = s.find(final_key)
+    if j != -1:
+        final_text = s[j + len(final_key):].strip()
+        i = s.find("analysis")
+        if i != -1 and i < j:
+            analysis_text = s[i + len("analysis"):j].strip()
+        else:
+            analysis_text = s[:j].strip()
+        return analysis_text, final_text
+    # no explicit final marker
+    if s.startswith("analysis"):
+        return s[len("analysis"):].strip(), ""
+    return "", s.strip()
+
+
+# ----------------------------
+# Inference
+# ----------------------------
 
 @spaces.GPU(duration=ZGPU_DURATION)
 def generate_long_prompt(
@@ -50,7 +92,7 @@ def generate_long_prompt(
     top_p: float,
     repetition_penalty: float,
 ) -> Tuple[str, str, str]:
-    global _pipe, _tok
+    global _pipe
     start = time.time()
 
     if _pipe is None:
@@ -60,7 +102,6 @@ def generate_long_prompt(
             torch_dtype="auto",
             device_map="auto",
         )
-        _tok = _pipe.tokenizer
 
     messages = _to_messages(policy, prompt)
 
@@ -72,42 +113,54 @@ def generate_long_prompt(
         top_p=top_p,
        repetition_penalty=repetition_penalty,
     )
-    print(outputs)
-    res = outputs[0]
-    print(res)
-    last = res.get("generated_text", [])
-    print(last)
-    if isinstance(last, list) and last:
-        last = last[-1]
 
-    analysis, content = _parse_harmony_output(last, _tok)
+    assistant_str = _extract_assistant_content(outputs)
+    analysis_text, final_text = _parse_harmony_output_from_string(assistant_str)
 
     elapsed = time.time() - start
     meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
-    return analysis or "(No analysis)", content or "(No answer)", meta
+    return analysis_text or "(No analysis)", final_text or "(No answer)", meta
+
 
+# ----------------------------
+# UI
+# ----------------------------
 
-CUSTOM_CSS = "/** Simple styling **/\n.gradio-container {font-family: ui-sans-serif, system-ui, Inter, Roboto;}\ntextarea {font-family: ui-monospace, monospace;}\nfooter {display:none;}"
+CUSTOM_CSS = "/** Pretty but simple **/\n:root { --radius: 14px; }\n.gradio-container { font-family: ui-sans-serif, system-ui, Inter, Roboto, Arial; }\n#hdr h1 { font-weight: 700; letter-spacing: -0.02em; }\ntextarea { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, 'Liberation Mono', 'Courier New', monospace; }\nfooter { display:none; }\n"
 
 with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
-    gr.Markdown("""# GPT‑OSS Harmony Demo\nProvide a **Policy**, a **Prompt**, and see both **Analysis** and **Answer** separately.""")
+    with gr.Column(elem_id="hdr"):
+        gr.Markdown("""
+        # GPT‑OSS (ZeroGPU) — Harmony View
+        Provide a **Policy**, a **Prompt**, and see **Analysis** and **Answer** separately.
+        """)
 
     with gr.Row():
-        with gr.Column(scale=1):
-            policy = gr.Textbox(label="Policy (system)", lines=20, placeholder="Enter the guiding rules and tone…")
-            prompt = gr.Textbox(label="Prompt (user)", lines=10, placeholder="Enter your main prompt…")
+        with gr.Column(scale=1, min_width=380):
+            policy = gr.Textbox(
+                label="Policy (system)",
+                lines=20,  # bigger than prompt
+                placeholder="Rules, tone, and constraints…",
+            )
+            prompt = gr.Textbox(
+                label="Prompt (user)",
+                lines=10,
+                placeholder="Your request…",
+            )
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(16, 4096, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="max_new_tokens")
                temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="temperature")
                top_p = gr.Slider(0.0, 1.0, value=DEFAULT_TOP_P, step=0.01, label="top_p")
                repetition_penalty = gr.Slider(0.8, 2.0, value=DEFAULT_REPETITION_PENALTY, step=0.05, label="repetition_penalty")
-            generate = gr.Button("Generate", variant="primary")
-        with gr.Column(scale=1):
-            analysis = gr.Textbox(label="Analysis (Harmony thinking)", lines=10)
-            answer = gr.Textbox(label="Answer", lines=10)
+            with gr.Row():
+                btn = gr.Button("Generate", variant="primary")
+                clr = gr.Button("Clear", variant="secondary")
+        with gr.Column(scale=1, min_width=380):
+            analysis = gr.Textbox(label="Analysis", lines=12)
+            answer = gr.Textbox(label="Answer", lines=12)
            meta = gr.Markdown()
 
-    generate.click(
+    btn.click(
        fn=generate_long_prompt,
        inputs=[policy, prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[analysis, answer, meta],
@@ -115,5 +168,21 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
         api_name="generate",
     )
 
+    def _clear():
+        return "", "", "", ""
+
+    clr.click(_clear, outputs=[policy, prompt, analysis, answer])
+
+    gr.Examples(
+        examples=[
+            ["Be concise and refuse unsafe requests.", "Explain transformers in 2 lines."],
+            [
+                "Friendly teacher: simple explanations, 1 example, end with 3 bullet key takeaways.",
+                "What is attention, briefly?",
+            ],
+        ],
+        inputs=[policy, prompt],
+    )
+
 if __name__ == "__main__":
     demo.queue(max_size=32).launch()
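
For context, a minimal standalone sketch of the parsing path this commit introduces. The `outputs` value below is fabricated to match the shape documented in `_extract_assistant_content` (not real model output), and `split_harmony` mirrors the string-find logic of `_parse_harmony_output_from_string`:

from typing import Tuple

def split_harmony(s: str) -> Tuple[str, str]:
    # Mirror of _parse_harmony_output_from_string: plain string finds,
    # splitting "analysis...assistantfinal..." into (analysis, final).
    j = s.find("assistantfinal")
    if j != -1:
        final_text = s[j + len("assistantfinal"):].strip()
        i = s.find("analysis")
        if i != -1 and i < j:
            analysis_text = s[i + len("analysis"):j].strip()
        else:
            analysis_text = s[:j].strip()
        return analysis_text, final_text
    if s.startswith("analysis"):
        return s[len("analysis"):].strip(), ""
    return "", s.strip()

# Fabricated pipeline-style output: shape only, content is made up.
outputs = [{
    "generated_text": [
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "Explain transformers in 2 lines."},
        {"role": "assistant",
         "content": "analysisUser wants two lines.assistantfinal"
                    "Transformers process token sequences with attention; "
                    "stacked layers build contextual representations."},
    ]
}]

# Last assistant turn, as _extract_assistant_content would find it.
assistant = next(m["content"] for m in reversed(outputs[0]["generated_text"])
                 if m.get("role") == "assistant")
analysis, final = split_harmony(assistant)
print(analysis)  # User wants two lines.
print(final)     # Transformers process token sequences with attention; ...

The "analysis" and "assistantfinal" markers are what remains of the Harmony channel headers used by gpt-oss models once special tokens are stripped from the decoded text, which is why simple string finds are enough here.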