Commit ec05364 · Parent: 9e04c4b
update: eval ui
application_pages/evaluation_app.py
CHANGED
@@ -1,6 +1,5 @@
 import asyncio
 import os
-import time
 from importlib import import_module
 
 import pandas as pd
@@ -12,212 +11,81 @@ from dotenv import load_dotenv
 from guardrails_genie.guardrails import GuardrailManager
 from guardrails_genie.llm import OpenAIModel
 from guardrails_genie.metrics import AccuracyMetric
-from guardrails_genie.utils import EvaluationCallManager
 
 
 def initialize_session_state():
     load_dotenv()
+    if "weave_project_name" not in st.session_state:
+        st.session_state.weave_project_name = "guardrails-genie"
     if "uploaded_file" not in st.session_state:
         st.session_state.uploaded_file = None
     if "dataset_name" not in st.session_state:
-        st.session_state.dataset_name = ""
+        st.session_state.dataset_name = None
     if "preview_in_app" not in st.session_state:
         st.session_state.preview_in_app = False
+    if "is_dataset_published" not in st.session_state:
+        st.session_state.is_dataset_published = False
+    if "publish_dataset_button" not in st.session_state:
+        st.session_state.publish_dataset_button = False
     if "dataset_ref" not in st.session_state:
         st.session_state.dataset_ref = None
-    if "dataset_previewed" not in st.session_state:
-        st.session_state.dataset_previewed = False
-    if "guardrail_names" not in st.session_state:
-        st.session_state.guardrail_names = []
-    if "guardrails" not in st.session_state:
-        st.session_state.guardrails = []
-    if "start_evaluation" not in st.session_state:
-        st.session_state.start_evaluation = False
-    if "evaluation_summary" not in st.session_state:
-        st.session_state.evaluation_summary = None
-    if "guardrail_manager" not in st.session_state:
-        st.session_state.guardrail_manager = None
-    if "evaluation_name" not in st.session_state:
-        st.session_state.evaluation_name = ""
-    if "show_result_table" not in st.session_state:
-        st.session_state.show_result_table = False
-    if "weave_client" not in st.session_state:
-        st.session_state.weave_client = weave.init(
-            project_name=os.getenv("WEAVE_PROJECT")
-        )
-    if "evaluation_call_manager" not in st.session_state:
-        st.session_state.evaluation_call_manager = None
-    if "call_id" not in st.session_state:
-        st.session_state.call_id = None
-    if "llama_guardrail_checkpoint" not in st.session_state:
-        st.session_state.llama_guardrail_checkpoint = None
-
-
-def initialize_guardrail():
-    guardrails = []
-    for guardrail_name in st.session_state.guardrail_names:
-        if guardrail_name == "PromptInjectionSurveyGuardrail":
-            survey_guardrail_model = st.sidebar.selectbox(
-                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
-            )
-            if survey_guardrail_model:
-                guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
-                )
-        elif guardrail_name == "PromptInjectionClassifierGuardrail":
-            classifier_model_name = st.sidebar.selectbox(
-                "Classifier Guardrail Model",
-                [
-                    "",
-                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
-                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
-                ],
-            )
-            if classifier_model_name:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(model_name=classifier_model_name)
-                )
-        elif guardrail_name == "PromptInjectionLlamaGuardrail":
-            llama_guardrail_checkpoint = st.sidebar.text_input(
-                "Llama Guardrail Checkpoint",
-                value=None,
-            )
-            st.session_state.llama_guardrail_checkpoint = llama_guardrail_checkpoint
-            if st.session_state.llama_guardrail_checkpoint is not None:
-                st.session_state.guardrails.append(
-                    getattr(
-                        import_module("guardrails_genie.guardrails"),
-                        guardrail_name,
-                    )(checkpoint=st.session_state.llama_guardrail_checkpoint)
-                )
-        else:
-            st.session_state.guardrails.append(
-                getattr(
-                    import_module("guardrails_genie.guardrails"),
-                    guardrail_name,
-                )()
-            )
-    st.session_state.guardrails = guardrails
-    st.session_state.guardrail_manager = GuardrailManager(guardrails=guardrails)
 
 
 initialize_session_state()
 st.title(":material/monitoring: Evaluation")
 
+weave_project_name = st.sidebar.text_input(
+    "Weave project name", value=st.session_state.weave_project_name
+)
+st.session_state.weave_project_name = weave_project_name
+if st.session_state.weave_project_name != "":
+    weave.init(project_name=st.session_state.weave_project_name)
+
 uploaded_file = st.sidebar.file_uploader(
     "Upload the evaluation dataset as a CSV file", type="csv"
 )
 st.session_state.uploaded_file = uploaded_file
-dataset_name = st.sidebar.text_input("Evaluation dataset name", value="")
-st.session_state.dataset_name = dataset_name
-preview_in_app = st.sidebar.toggle("Preview in app", value=False)
-st.session_state.preview_in_app = preview_in_app
-
-if st.session_state.uploaded_file is not None and st.session_state.dataset_name != "":
-    with st.expander("Evaluation Dataset Preview", expanded=True):
-        dataframe = pd.read_csv(st.session_state.uploaded_file)
-        data_list = dataframe.to_dict(orient="records")
-        ...
-            )
-
-if st.session_state.dataset_previewed:
-    guardrail_names = st.sidebar.multiselect(
-        "Select Guardrails",
-        options=[
-            cls_name
-            for cls_name, cls_obj in vars(
-                import_module("guardrails_genie.guardrails")
-            ).items()
-            if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
-        ],
-    )
-    st.session_state.guardrail_names = guardrail_names
-
-    if st.session_state.guardrail_names != []:
-        initialize_guardrail()
-        evaluation_name = st.sidebar.text_input("Evaluation name", value="")
-        st.session_state.evaluation_name = evaluation_name
-        if st.session_state.guardrail_manager is not None:
-            if st.sidebar.button("Start Evaluation"):
-                st.session_state.start_evaluation = True
-            if st.session_state.start_evaluation:
-                evaluation = weave.Evaluation(
-                    dataset=st.session_state.dataset_ref,
-                    scorers=[AccuracyMetric()],
-                    streamlit_mode=True,
-                )
-                with st.expander("Evaluation Results", expanded=True):
-                    evaluation_summary, call = asyncio.run(
-                        evaluation.evaluate.call(
-                            evaluation,
-                            st.session_state.guardrail_manager,
-                            __weave={
-                                "display_name": "Evaluation.evaluate:"
-                                + st.session_state.evaluation_name
-                            },
-                        )
-                    )
-                    x_axis = list(evaluation_summary["AccuracyMetric"].keys())
-                    y_axis = [
-                        evaluation_summary["AccuracyMetric"][x_axis_item]
-                        for x_axis_item in x_axis
-                    ]
-                    st.bar_chart(
-                        pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
-                        x="Metric",
-                        y="Score",
-                    )
-                    st.session_state.evaluation_summary = evaluation_summary
-                    st.session_state.call_id = call.id
-                    st.session_state.start_evaluation = False
-
-        ...
-                entity="geekyrakshit",
-                project="guardrails-genie",
-                call_id=st.session_state.call_id,
-            )
-        )
-        for guardrail_name in st.session_state.guardrail_names:
-            st.session_state.evaluation_call_manager.call_list.append(
-                {
-                    "guardrail_name": guardrail_name,
-                    "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(),
-                }
-            )
-        rich.print(
-            st.session_state.evaluation_call_manager.call_list
-        )
-        st.dataframe(
-            st.session_state.evaluation_call_manager.render_calls_to_streamlit()
-        )
-        if st.session_state.evaluation_call_manager.show_warning_in_app:
-            st.warning(
-                f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
-            )
-        st.markdown(
-            f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
-        )
-        st.session_state.evaluation_call_manager = None
 
+if st.session_state.uploaded_file is not None:
+    dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
+    st.session_state.dataset_name = dataset_name
+    preview_in_app = st.sidebar.toggle("Preview in app", value=False)
+    st.session_state.preview_in_app = preview_in_app
+    publish_dataset_button = st.sidebar.button("Publish dataset")
+    st.session_state.publish_dataset_button = publish_dataset_button
+
+    if (
+        st.session_state.publish_dataset_button
+        and (
+            st.session_state.dataset_name is not None
+            and st.session_state.dataset_name != ""
+        )
+    ):
+        with st.expander("Evaluation Dataset Preview", expanded=True):
+            dataframe = pd.read_csv(st.session_state.uploaded_file)
+            data_list = dataframe.to_dict(orient="records")
+
+            dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
+            st.session_state.dataset_ref = weave.publish(dataset)
+
+            entity = st.session_state.dataset_ref.entity
+            project = st.session_state.dataset_ref.project
+            dataset_name = st.session_state.dataset_name
+            digest = st.session_state.dataset_ref._digest
+            dataset_url = f"https://wandb.ai/{entity}/{project}/weave/objects/{dataset_name}/versions/{digest}"
+            st.markdown(
+                f"Dataset published to [**Weave**]({dataset_url})"
+            )
+
+            if preview_in_app:
+                st.dataframe(dataframe.head(20))
+                if len(dataframe) > 20:
+                    st.markdown(
+                        f"⚠️ Dataset is too large to preview in app, please explore in the [**Weave UI**]({dataset_url})"
+                    )
+
+        st.session_state.is_dataset_published = True
+
+    if st.session_state.is_dataset_published:
+        st.write("Maza Ayega")
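This change narrows the page to dataset management: the guardrail-selection UI, the weave.Evaluation run, and the EvaluationCallManager result table are removed, and the page now stops after publishing the uploaded CSV as a versioned Weave dataset. The trailing st.write("Maza Ayega") (Hindi, roughly "it'll be fun") is a visible placeholder for the evaluation UI to come. Below is a minimal sketch of the new publish step outside Streamlit; the CSV path and dataset name are hypothetical, and the URL construction mirrors the committed code, which assumes the ref returned by weave.publish exposes entity, project, and the private _digest attribute.

import pandas as pd
import weave

# Point Weave at the target project; the W&B entity is inferred from the
# logged-in user unless given explicitly as "entity/project".
weave.init(project_name="guardrails-genie")

# Load the evaluation CSV and convert each row to a dict, as the app does.
dataframe = pd.read_csv("evaluation_dataset.csv")  # hypothetical path
data_list = dataframe.to_dict(orient="records")

# Publish as a named, versioned Weave dataset.
dataset_name = "my-eval-dataset"  # hypothetical name
dataset_ref = weave.publish(weave.Dataset(name=dataset_name, rows=data_list))

# Rebuild the UI link the same way the committed code does (note the
# reliance on the private _digest attribute of the returned ref).
dataset_url = (
    f"https://wandb.ai/{dataset_ref.entity}/{dataset_ref.project}"
    f"/weave/objects/{dataset_name}/versions/{dataset_ref._digest}"
)
print(dataset_url)

Re-publishing under the same name creates a new version with a new digest, which is why the app links to a specific versions/{digest} URL rather than to the bare object.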
guardrails_genie/train/llama_guard.py
CHANGED
@@ -3,12 +3,13 @@ import shutil
 from glob import glob
 from typing import Optional
 
+# import torch.optim as optim
+import bitsandbytes.optim as optim
 import plotly.graph_objects as go
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
 from datasets import load_dataset
 from pydantic import BaseModel
 from rich.progress import track
@@ -335,8 +336,8 @@ class LlamaGuardFineTuner:
 
     def train(
         self,
-        batch_size: int =
-        ...
+        batch_size: int = 16,
+        starting_lr: float = 1e-7,
         num_classes: int = 2,
         log_interval: int = 1,
         save_interval: int = 50,
@@ -358,7 +359,7 @@
 
         Args:
             batch_size (int, optional): The number of samples per batch during training.
-            ...
+            starting_lr (float, optional): The starting learning rate for the optimizer.
             num_classes (int, optional): The number of output classes for the classifier.
             log_interval (int, optional): The interval (in batches) at which to log the loss.
             save_interval (int, optional): The interval (in batches) at which to save model checkpoints.
@@ -377,7 +378,7 @@
         wandb.config.dataset_args = self.dataset_args.model_dump()
         wandb.config.model_name = self.model_name
         wandb.config.batch_size = batch_size
-        wandb.config.
+        wandb.config.starting_lr = starting_lr
         wandb.config.num_classes = num_classes
         wandb.config.log_interval = log_interval
        wandb.config.save_interval = save_interval
@@ -387,7 +388,16 @@
         self.model.num_labels = num_classes
         self.model = self.model.to(self.device)
         self.model.train()
-        optimizer = optim.AdamW(self.model.parameters(), lr=
+        # optimizer = optim.AdamW(self.model.parameters(), lr=starting_lr)
+        optimizer = optim.Lion(
+            self.model.parameters(), lr=starting_lr, weight_decay=0.01
+        )
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=starting_lr,
+            steps_per_epoch=len(self.train_dataset) // batch_size + 1,
+            epochs=1,
+        )
         data_loader = DataLoader(
             self.train_dataset,
             batch_size=batch_size,
@@ -405,9 +415,14 @@
             loss = outputs.loss
             optimizer.zero_grad()
             loss.backward()
+
+            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clipping)
+
             optimizer.step()
+            scheduler.step()
             if (i + 1) % log_interval == 0:
                 wandb.log({"loss": loss.item()}, step=i + 1)
+                wandb.log({"learning_rate": scheduler.get_last_lr()[0]}, step=i + 1)
             if progress_bar:
                 progress_percentage = (i + 1) * 100 // len(data_loader)
                 progress_bar.progress(
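The training change swaps torch.optim.AdamW for the Lion optimizer from bitsandbytes and drives it with a per-batch OneCycleLR schedule, logging the learning rate alongside the loss. One naming quirk worth knowing: starting_lr is passed to the scheduler as max_lr, so it is really the peak of the one-cycle curve; with PyTorch's defaults (div_factor=25) the first batch actually runs at max_lr / 25. A self-contained sketch of the wiring, with a dummy model and loss standing in for the classifier being fine-tuned:

import bitsandbytes.optim as optim  # bitsandbytes optimizers generally expect a CUDA build
import torch

model = torch.nn.Linear(128, 2)  # stand-in for the fine-tuned classifier head
starting_lr, batch_size, num_batches = 1e-7, 16, 100

# Lion from bitsandbytes, matching the committed weight decay.
optimizer = optim.Lion(model.parameters(), lr=starting_lr, weight_decay=0.01)

# One cycle over a single epoch: the LR warms up to max_lr, then anneals well below it.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=starting_lr,
    steps_per_epoch=num_batches,
    epochs=1,
)

for step in range(num_batches):
    inputs = torch.randn(batch_size, 128)
    loss = model(inputs).pow(2).mean()  # dummy loss for the sketch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # stepped once per batch, matching steps_per_epoch above
    current_lr = scheduler.get_last_lr()[0]  # the value the commit logs to wandb

Because scheduler.step() runs once per batch, steps_per_epoch * epochs must cover every optimizer step; OneCycleLR raises an error if stepped past its total count. The commented-out clip_grad_norm_ call in the diff leaves a hook for gradient clipping, but note that no gradient_clipping argument exists in the committed train() signature.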