Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -999,30 +999,34 @@ with ui.navset_card_tab(id="tab"):
|
|
| 999 |
multiple=True,
|
| 1000 |
selected=["compliment", "cross_entropy", "headless"]
|
| 1001 |
)
|
| 1002 |
-
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
for
|
| 1010 |
-
for
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
|
|
|
| 1014 |
f = interp1d(np.linspace(0, 1, len(y)), y)
|
| 1015 |
loss_rates.append(f(x))
|
| 1016 |
-
labels.append(str(param_type) +'_'+loss_type +'_'+model_type)
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
|
|
|
|
|
|
|
|
|
| 1026 |
|
| 1027 |
import matplotlib as mpl
|
| 1028 |
@render.plot()
|
|
|
|
| 999 |
multiple=True,
|
| 1000 |
selected=["compliment", "cross_entropy", "headless"]
|
| 1001 |
)
|
| 1002 |
+
def plot_loss_rates_model(df, param_types, loss_types, model_types):
|
| 1003 |
+
# interplot each column to be same number of points
|
| 1004 |
+
x = np.linspace(0, 1, 1000)
|
| 1005 |
+
loss_rates = []
|
| 1006 |
+
labels = []
|
| 1007 |
+
|
| 1008 |
+
for param_type in param_types:
|
| 1009 |
+
for loss_type in loss_types:
|
| 1010 |
+
for model_type in model_types:
|
| 1011 |
+
y = df[(df['param_type'] == param_type) & (df['loss_type'] == loss_type) & (df['model_type'] == model_type)]['loss'].astype('float').values
|
| 1012 |
+
print(y)
|
| 1013 |
+
|
| 1014 |
+
if len(y) > 0:
|
| 1015 |
f = interp1d(np.linspace(0, 1, len(y)), y)
|
| 1016 |
loss_rates.append(f(x))
|
| 1017 |
+
labels.append(str(param_type) + '_' + loss_type + '_' + model_type)
|
| 1018 |
+
|
| 1019 |
+
fig, ax = plt.subplots()
|
| 1020 |
+
print(loss_rates)
|
| 1021 |
+
|
| 1022 |
+
for i, loss_rate in enumerate(loss_rates):
|
| 1023 |
+
ax.plot(x, loss_rate, label=labels[i])
|
| 1024 |
+
|
| 1025 |
+
ax.legend()
|
| 1026 |
+
ax.set_xlabel('Training steps')
|
| 1027 |
+
ax.set_ylabel('Loss rate')
|
| 1028 |
+
|
| 1029 |
+
return fig
|
| 1030 |
|
| 1031 |
import matplotlib as mpl
|
| 1032 |
@render.plot()
|