File size: 6,127 Bytes
d369456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b723
d369456
 
 
 
 
 
 
dd1b723
d369456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b723
d369456
 
 
 
 
 
 
 
dd1b723
 
d369456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b723
 
d369456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b723
d369456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b723
d369456
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// index.js
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/[email protected]';

// DOM elements: resolve each control once at startup (ids must match the page markup).
const byId = (id) => document.getElementById(id);
const taskSelect = byId('task-select');
const deviceSelect = byId('device-select');
const inputText = byId('input-text');
const runBtn = byId('run-btn');
const outputContainer = byId('output');

// Model mapping: pipeline task name -> model id passed to pipeline().
// Both classification tasks share the same DistilBERT SST-2 checkpoint.
const sst2Checkpoint = 'distilbert-base-uncased-finetuned-sst-2-english';
const modelMap = {
    'text-classification': sst2Checkpoint,
    'sentiment-analysis': sst2Checkpoint,
    'question-answering': 'deepset/roberta-base-squad2',
    'text-generation': 'gpt2',
    'translation': 't5-small',
};

// Initialize pipeline cache: the most recently loaded pipeline and the task it was built for.
let currentPipeline = null;
let currentTask = null;

/**
 * Report whether the browser exposes the WebGPU entry point (`navigator.gpu`).
 * Only the property's presence is checked; no adapter is requested.
 *
 * @returns {boolean} true when `navigator` has a `gpu` property.
 */
function checkWebGPUSupport() {
    const hasGpuEntryPoint = 'gpu' in navigator;
    return hasGpuEntryPoint;
}

// Update device options based on support: when WebGPU is unavailable, disable
// its option (if the page has one) and switch the selector to CPU.
if (!checkWebGPUSupport()) {
    const webgpuOption = document.querySelector('#device-select option[value="webgpu"]');
    // querySelector returns null when the markup has no WebGPU option. Guard it so
    // a markup change cannot throw here, which would abort the rest of this module
    // (including every event-listener registration below).
    if (webgpuOption) {
        webgpuOption.disabled = true;
        webgpuOption.textContent = 'GPU (WebGPU) - Not Supported';
    }
    deviceSelect.value = 'cpu';
}

/**
 * Load the transformers.js pipeline for `task` using the model listed in modelMap.
 *
 * A loading message is shown while the model loads. WebGPU is requested only
 * when `device` is 'webgpu' and checkWebGPUSupport() agrees. Otherwise pipeline()
 * is called without options, so the library picks its default backend.
 *
 * @param {string} task - pipeline task name (a modelMap key).
 * @param {string} device - device selector value, e.g. 'webgpu' or 'cpu'.
 * @returns {Promise<Function|null>} the loaded pipeline, or null after the load
 *   error has been logged and rendered into outputContainer.
 */
async function createPipeline(task, device) {
    const modelName = modelMap[task];

    outputContainer.innerHTML = '<p class="loading">Loading model...</p>';

    try {
        const wantsWebGpu = device === 'webgpu' && checkWebGPUSupport();
        return wantsWebGpu
            ? await pipeline(task, modelName, { device: 'webgpu' })
            : await pipeline(task, modelName);
    } catch (loadError) {
        console.error('Error creating pipeline:', loadError);
        outputContainer.innerHTML = `<p class="error">Error loading model: ${loadError.message}</p>`;
        return null;
    }
}

// Function to run the model

// Device the cached pipeline was built for. A device change must force a reload,
// otherwise switching between CPU and WebGPU after the first run has no effect.
let currentDevice = null;

/**
 * Call `pipe` with the arguments the given task expects.
 *
 * @param {Function} pipe - pipeline returned by createPipeline.
 * @param {string} task - one of the task names handled below.
 * @param {string} text - trimmed, non-empty user input.
 * @returns {Promise<*>} raw pipeline output, to be rendered by displayResults.
 * @throws {Error} 'Unsupported task' for any other task name (as a rejection).
 */
async function invokePipeline(pipe, task, text) {
    switch (task) {
        case 'text-classification':
        case 'sentiment-analysis':
            return pipe(text);
        case 'question-answering': {
            // Every line but the last is the context and the last line is the question.
            // A single line is treated as the question against a fixed demo context.
            const lines = text.split('\n');
            const [context, question] = lines.length > 1
                ? [lines.slice(0, -1).join(' '), lines[lines.length - 1]]
                : ['The sky is blue.', text];
            return pipe(question, context);
        }
        case 'text-generation':
            return pipe(text, { max_new_tokens: 50 });
        case 'translation':
            // The translation model (t5-small in modelMap) selects its task from a
            // text prefix. src_lang/tgt_lang are only consumed by tokenizers that use
            // language codes, and transformers.js only auto-applies T5's prefix for
            // task names like 'translation_en_to_de', so the prefix is added here.
            // NOTE(review): drop the prefix if modelMap switches to a non-T5 model.
            return pipe(`translate English to German: ${text}`, {
                src_lang: 'en',
                tgt_lang: 'de',
            });
        default:
            throw new Error('Unsupported task');
    }
}

/**
 * Read the task, device and input text from the form, then run the model.
 * The pipeline is (re)built only when the task or device changed since the last
 * load. The result, or an error message, is rendered into outputContainer.
 *
 * The run button is disabled while a run is in progress so repeated clicks
 * cannot start overlapping model loads.
 *
 * @returns {Promise<void>} resolves once the output area has been updated.
 */
async function runModel() {
    const task = taskSelect.value;
    const device = deviceSelect.value;
    const text = inputText.value.trim();

    if (!text) {
        outputContainer.innerHTML = '<p class="error">Please enter some text to process.</p>';
        return;
    }

    runBtn.disabled = true;
    try {
        // Rebuild when nothing is cached, or the cache was built for another task/device.
        if (!currentPipeline || currentTask !== task || currentDevice !== device) {
            currentPipeline = await createPipeline(task, device);
            currentTask = task;
            currentDevice = device;
        }

        // createPipeline has already rendered the load error.
        if (!currentPipeline) {
            return;
        }

        // Show processing state
        outputContainer.innerHTML = '<p class="loading">Processing...</p>';

        try {
            const result = await invokePipeline(currentPipeline, task, text);
            displayResults(result, task);
        } catch (error) {
            console.error('Error running model:', error);
            outputContainer.innerHTML = `<p class="error">Error processing text: ${error.message}</p>`;
        }
    } finally {
        runBtn.disabled = false;
    }
}

// Function to display results based on task

/**
 * Escape HTML-significant characters so text is shown literally when it is
 * interpolated into an innerHTML template.
 *
 * @param {*} value - any value; converted with String() first.
 * @returns {string} the escaped text.
 */
function escapeHtml(value) {
    return String(value)
        .replaceAll('&', '&amp;')
        .replaceAll('<', '&lt;')
        .replaceAll('>', '&gt;')
        .replaceAll('"', '&quot;')
        .replaceAll("'", '&#39;');
}

/**
 * Render a pipeline result into outputContainer, formatted per task.
 *
 * Every interpolated string is HTML-escaped. Model output can echo the user's
 * own input (the QA context and the generation prompt come straight from the
 * textarea). Rendering it raw would let typed markup such as `<img onerror=...>`
 * execute, or make text containing `<` disappear.
 *
 * @param {*} result - raw pipeline output:
 *   - classification: an array of {label, score};
 *   - question answering: {answer, score};
 *   - generation / translation: an array whose first element has
 *     generated_text / translation_text.
 * @param {string} task - the task that produced `result`.
 */
function displayResults(result, task) {
    let content = '';

    switch (task) {
        case 'text-classification':
        case 'sentiment-analysis':
            content = `
                <h4>Classification Results:</h4>
                <ul>
                    ${result.map(item => `
                        <li>
                            <strong>${escapeHtml(item.label)}:</strong> ${(item.score * 100).toFixed(2)}%
                        </li>
                    `).join('')}
                </ul>
            `;
            break;
        case 'question-answering':
            content = `
                <h4>Answer:</h4>
                <p>${escapeHtml(result.answer)}</p>
                <p><strong>Confidence:</strong> ${(result.score * 100).toFixed(2)}%</p>
            `;
            break;
        case 'text-generation':
            content = `
                <h4>Generated Text:</h4>
                <p>${escapeHtml(result[0].generated_text)}</p>
            `;
            break;
        case 'translation':
            content = `
                <h4>Translation:</h4>
                <p>${escapeHtml(result[0].translation_text)}</p>
            `;
            break;
        default:
            content = `<pre>${escapeHtml(JSON.stringify(result, null, 2))}</pre>`;
    }

    outputContainer.innerHTML = content;
}

// Event listeners: run the selected model on the current input whenever the Run
// button is clicked. runModel is registered directly, so it receives the click event
// as its (unused) argument.
runBtn.addEventListener('click', runModel);

// Input hints for tasks that need more guidance than the generic prompt.
const taskPlaceholders = new Map([
    ['question-answering', "Enter context and question separated by a new line\nContext: The sky is blue.\nQuestion: What color is the sky?"],
    ['text-generation', "Enter a prompt to generate text from..."],
    ['translation', "Enter text to translate from English to German..."],
]);

// Handle task change: drop the cached pipeline so the next run loads the new
// task's model, then show the input hint for the newly selected task.
taskSelect.addEventListener('change', () => {
    currentPipeline = null;
    currentTask = null;

    const selectedTask = taskSelect.value;
    inputText.placeholder = taskPlaceholders.get(selectedTask) ?? "Enter your text here...";
});

// Initialize with default task: fire the task selector's 'change' handler once at
// startup, so the input placeholder matches the initially selected task and the
// pipeline cache starts cleared.
taskSelect.dispatchEvent(new Event('change'));