Upload 13 files
- app.py +728 -0
- codegen_torch.py +187 -0
- gpt2_pytorch.py +210 -0
- image_to_3d_openlrm.py +31 -0
- imagegen_vae_unet.py +164 -0
- lipsync_wav2lip.py +57 -0
- musicgen_torch.py +36 -0
- sentiment_roberta.py +195 -0
- stt_wav2vec2.py +46 -0
- summarization_bart.py +34 -0
- text_to_video_clip4clip.py +34 -0
- translation_mbart.py +267 -0
- tts_vits.py +57 -0
app.py
ADDED
@@ -0,0 +1,728 @@
import os
import sys
import torch
import random
import re
import json
import math
import copy
import requests
from functools import lru_cache
from tqdm import tqdm
from torch.nn.parameter import Parameter
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
import time
import threading
import queue
import httpx
import asyncio
import torch.nn as nn
import torch.nn.functional as F
import uuid
import wget
from duckduckgo_search import DDGS
import warnings
from datetime import datetime
import unicodedata
import nltk
import torchaudio
import logging
from PIL import Image
from io import BytesIO
import sentencepiece as spm
from flask import Flask, request, jsonify, send_file, Response
from flask_cors import CORS

nltk.download('punkt', quiet=True)

GPT2_FOLDER = "./GPT2"
MODEL_FILE = "gpt2-pytorch_model.bin"
ENCODER_FILE = "encoder.json"
VOCAB_FILE = "vocab.bpe"
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/encoder.json"
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/vocab.bpe"
GPT2_FILES_URLS = [
    (MODEL_URL, MODEL_FILE),
    (ENCODER_URL, ENCODER_FILE),
    (VOCAB_URL, VOCAB_FILE),
]

TEXT_GENERATION_RATE = 40000
MAX_LENGTH = 1024
MAX_XDD = 5
END_OF_TEXT_TOKEN = "<|endoftext|>"

html_code = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>AI Text Generation</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
    <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
    <style>
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: #f0f0f0;
            color: #333;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            min-height: 100vh;
        }
        .container {
            width: 95%;
            max-width: 900px;
            padding: 20px;
            background-color: #fff;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
            border-radius: 8px;
            margin-top: 20px;
            margin-bottom: 20px;
            display: flex;
            flex-direction: column;
        }
        .header {
            text-align: center;
            margin-bottom: 20px;
        }
        .header h1 {
            font-size: 2em;
            color: #333;
        }
        .form-group {
            margin-bottom: 15px;
        }
        .form-group textarea {
            width: 100%;
            padding: 10px;
            border: 1px solid #ccc;
            border-radius: 5px;
            font-size: 16px;
            box-sizing: border-box;
            resize: vertical;
        }
        button {
            padding: 10px 15px;
            border: none;
            border-radius: 5px;
            background-color: #007bff;
            color: white;
            font-size: 18px;
            cursor: pointer;
            transition: background-color 0.3s ease;
        }
        button:hover {
            background-color: #0056b3;
        }
        #output {
            margin-top: 20px;
            padding: 15px;
            border: 1px solid #ddd;
            border-radius: 5px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
            word-break: break-word;
            overflow-y: auto;
            max-height: 100vh;
        }
        #output strong {
            font-weight: bold;
        }
        .animated-text {
            position: fixed;
            top: 20px;
            left: 20px;
            font-size: 1.5em;
            color: rgba(0, 0, 0, 0.1);
            pointer-events: none;
            z-index: -1;
        }
        @media (max-width: 768px) {
            .container {
                width: 98%;
                margin-top: 10px;
                margin-bottom: 10px;
                padding: 15px;
            }
            .header h1 {
                font-size: 1.8em;
            }
            .form-group textarea, .form-group input[type="text"] {
                font-size: 14px;
                padding: 8px;
            }
            button {
                font-size: 16px;
                padding: 8px 12px;
            }
            #output {
                font-size: 14px;
                padding: 10px;
                margin-top: 15px;
            }
        }
    </style>
</head>
<body>
    <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
    <div class="container">
        <div class="header animate__animated animate__fadeInDown">
        </div>
        <div class="form-group animate__animated animate__fadeInLeft">
            <textarea id="text" rows="5" placeholder="Enter text"></textarea>
        </div>
        <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
        <div id="output" class="animate__animated">
            <strong>Response:</strong><br>
            <div id="generatedText"></div>
        </div>
    </div>
    <script>
        let eventSource = null;
        let accumulatedText = "";
        let lastResponse = "";
        let currentSpan = null;
        let messageCounter = 0;

        async function generateText() {
            const inputText = document.getElementById("text").value;
            let generatedTextDiv = document.getElementById("generatedText");
            generatedTextDiv.innerHTML = "";
            accumulatedText = "";
            lastResponse = "";
            currentSpan = null;
            messageCounter = 0;

            if (eventSource) {
                eventSource.close();
            }
            const temp = 0.7;
            const top_k_val = 40;
            const top_p_val = 0.0;
            const repetition_penalty_val = 1.2;
            eventSource = new EventSource(`/generate_stream?text=${encodeURIComponent(inputText)}&temp=${temp}&top_k=${top_k_val}&top_p=${top_p_val}&reppenalty=${repetition_penalty_val}`);
            eventSource.onmessage = function(event) {
                if (event.data === "<END_STREAM>") {
                    eventSource.close();
                    const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\\s+(?=[.,,。])/g, '').trim();
                    if (currentResponse === lastResponse.trim()) {
                        accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
                    } else {
                        lastResponse = currentResponse;
                    }
                    document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
                    return;
                }
                try {
                    const jsonData = JSON.parse(event.data);
                    const token = jsonData.token;
                    if (token === "<|endoftext|>" || token === "<END_STREAM>") {
                        return;
                    }
                    if (token === "<NEW_MESSAGE>") {
                        messageCounter++;
                        if (messageCounter > 1) {
                            generatedTextDiv.innerHTML += "<br><br><hr style='border-top: 1px dashed #8c8b8b; margin-top: 10px; margin-bottom: 10px;'><strong>Continued Response:</strong><br><div id='generatedText_" + messageCounter + "'></div>";
                            generatedTextDiv = document.getElementById("generatedText_" + messageCounter);
                            accumulatedText = "";
                        }
                        return;
                    }
                    accumulatedText += token + " ";
                } catch (e) {
                    console.error("Error parsing SSE data:", event.data, e);
                }
            };
            eventSource.onerror = function(error) {
                console.error("SSE error", error);
                eventSource.close();
            };
            const outputDiv = document.getElementById("output");
            outputDiv.classList.add("show");
        }
    </script>
</body>
</html>
"""

TRANSLATION_FOLDER = "./TranslationModel"
TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_FILE = "config.json"
TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_MODEL_FILES_URLS = [
    (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
    (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
    (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
]

CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]

TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"
TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]

STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB),
]

SENTIMENT_FOLDER = "./SentimentModel"
SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
SENTIMENT_VOCAB = "sentiment_vocab.json"
SENTIMENT_CONFIG = "config.json"
SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/pytorch_model.bin"
SENTIMENT_VOCAB_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/vocab.json"
SENTIMENT_CONFIG_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/config.json"
SENTIMENT_FILES_URLS = [
    (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
    (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
    (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG),
]

IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_WEIGHTS_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG),
]

LIPSYNC_FOLDER = "./LipSyncModel"
LIPSYNC_MODEL_WEIGHTS = "lipsync_expert.pth"
LIPSYNC_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Flipsync%5Fexpert%2Epth"
LIPSYNC_FILES_URLS = [
    (LIPSYNC_MODEL_WEIGHTS_URL, LIPSYNC_MODEL_WEIGHTS),
]

WAV2LIP_FOLDER = "./Wav2LipModel"
WAV2LIP_MODEL_WEIGHTS = "wav2lip_gan.pth"
WAV2LIP_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Fwav2lip%5Fgan%2Epth"
WAV2LIP_FILES_URLS = [
    (WAV2LIP_MODEL_WEIGHTS_URL, WAV2LIP_MODEL_WEIGHTS),
]

MUSICGEN_FOLDER = "./MusicGenModel"
MUSICGEN_MODEL_NAME = "melody"
MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
MUSICGEN_CONFIG = "config.json"
MUSICGEN_SAMPLE_RATE = 32000
MUSICGEN_DURATION = 8
MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
MUSICGEN_FILES_URLS = [
    (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
    (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG),
]

CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"

TRANSLATION_SPM_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_SPM = "sentencepiece.bpe.model"

TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
TEXT_TO_VIDEO_MODEL_WEIGHTS = "pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG = "config.json"
TEXT_TO_VIDEO_VOCAB = "vocab.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/config.json"
TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/vocab.json"
TEXT_TO_VIDEO_FILES_URLS = [
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL, TEXT_TO_VIDEO_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_CONFIG_URL, TEXT_TO_VIDEO_CONFIG),
    (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
]

SUMMARIZATION_FOLDER = "./SummarizationModel"
SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
SUMMARIZATION_CONFIG = "config.json"
SUMMARIZATION_VOCAB = "vocab.json"
SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
SUMMARIZATION_FILES_URLS = [
    (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
    (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
    (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB),
]

IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
IMAGE_TO_3D_CONFIG = "config.json"
IMAGE_TO_3D_MODEL_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
IMAGE_TO_3D_FILES_URLS = [
    (IMAGE_TO_3D_MODEL_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
    (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG),
]


state_dict = None
enc = None
config = None
model = None
device = torch.device("cpu")
news_clf = None
tfidf_vectorizer = None
text_queue = queue.Queue()
categories = None
is_training = False
background_threads = []
feedback_queue = queue.Queue()
reasoning_queue = queue.Queue()
seen_responses = set()
tts_model = None
stt_model = None
sentiment_model = None
imagegen_model = None
lipsync_model = None
wav2lip_model = None
musicgen_model = None
translation_model = None
codegen_model = None
text_to_video_model = None
summarization_model = None
image_to_3d_model = None
tts_pipeline = False
stt_pipeline = False
sentiment_pipeline = False
imagegen_pipeline = False
translation_pipeline = False
codegen_pipeline = False
text_to_video_pipeline = False
summarization_pipeline = False
image_to_3d_pipeline = False
stt_tokenizer = None
stt_processor = None
sentiment_tokenizer = None
sentiment_model_instance = None
imagegen_vae = None
imagegen_unet = None
imagegen_scheduler = None
musicgen_model_instance = None
musicgen_tokenizer = None
musicgen_processor = None
translation_model_instance = None
translation_tokenizer = None
codegen_model_instance = None
codegen_tokenizer = None
codegen_sp = None
translation_sp = None
text_to_video_tokenizer = None
text_to_video_model_instance = None
summarization_tokenizer = None
summarization_model_instance = None
image_to_3d_config = None
image_to_3d_model_instance = None
app = Flask(__name__)
CORS(app)

from gpt2_pytorch import *
from tts_vits import *
from stt_wav2vec2 import *
from sentiment_roberta import *
from imagegen_vae_unet import *
from musicgen_torch import *
from translation_mbart import *
from codegen_torch import *
from text_to_video_clip4clip import *
from summarization_bart import *
from image_to_3d_openlrm import *

def download_file(url, filename):
    os.makedirs(os.path.dirname(filename), exist_ok=True)  # Ensure directory exists
    if not os.path.exists(filename):
        print(f"Downloading {filename} from {url}...")
        try:
            wget.download(url, out=filename)  # Specify output filename directly
            print(f"Downloaded {filename} successfully.")
        except Exception as e:
            print(f"Error downloading {filename}: {e}")

def ensure_folder_and_files_exist(folder_path, files_urls):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder_path}' created.")

    for url, filename in files_urls:
        filepath = os.path.join(folder_path, filename)
        download_file(url, filepath)

def ensure_single_file_exists(folder_path, file_url, filename):
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder_path}' created.")
    filepath = os.path.join(folder_path, filename)
    download_file(file_url, filepath)


def ensure_gpt2_files_exist():
    ensure_folder_and_files_exist(GPT2_FOLDER, GPT2_FILES_URLS)

def ensure_translation_files_exist():
    ensure_folder_and_files_exist(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
    ensure_single_file_exists(TRANSLATION_FOLDER, TRANSLATION_SPM_URL, TRANSLATION_SPM)

def ensure_codegen_files_exist():
    ensure_folder_and_files_exist(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
    ensure_single_file_exists(CODEGEN_FOLDER, CODEGEN_SPM_URL, CODEGEN_SPM)

def ensure_tts_files_exist():
    ensure_folder_and_files_exist(TTS_FOLDER, TTS_FILES_URLS)

def ensure_stt_files_exist():
    ensure_folder_and_files_exist(STT_FOLDER, STT_FILES_URLS)

def ensure_sentiment_files_exist():
    ensure_folder_and_files_exist(SENTIMENT_FOLDER, SENTIMENT_FILES_URLS)

def ensure_imagegen_files_exist():
    ensure_folder_and_files_exist(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)

def ensure_lipsync_files_exist():
    ensure_folder_and_files_exist(LIPSYNC_FOLDER, LIPSYNC_FILES_URLS)

def ensure_wav2lip_files_exist():
    ensure_folder_and_files_exist(WAV2LIP_FOLDER, WAV2LIP_FILES_URLS)

def ensure_musicgen_files_exist():
    ensure_folder_and_files_exist(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)

def ensure_text_to_video_files_exist():
    ensure_folder_and_files_exist(TEXT_TO_VIDEO_FOLDER, TEXT_TO_VIDEO_FILES_URLS)

def ensure_summarization_files_exist():
    ensure_folder_and_files_exist(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)

def ensure_image_to_3d_files_exist():
    ensure_folder_and_files_exist(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)

def ensure_all_model_files_exist():  # Define the function here, before it's called
    ensure_gpt2_files_exist()
    ensure_translation_files_exist()
    ensure_codegen_files_exist()
    ensure_tts_files_exist()
    ensure_stt_files_exist()
    ensure_sentiment_files_exist()
    ensure_imagegen_files_exist()
    ensure_lipsync_files_exist()
    ensure_wav2lip_files_exist()
    ensure_musicgen_files_exist()
    ensure_text_to_video_files_exist()
    ensure_summarization_files_exist()
    ensure_image_to_3d_files_exist()


@app.route("/", methods=['GET'])
async def html_handler():
    return html_code

@app.route("/generate_stream", methods=['GET'])
async def generate_stream_api():
    text_input = request.args.get("text")
    temperature = float(request.args.get("temp", 0.7))
    top_k = int(request.args.get("top_k", 40))
    top_p = float(request.args.get("top_p", 0.0))
    reppenalty = float(request.args.get("reppenalty", 1.2))
    return Response(generate_stream_generator(text_input, temperature, top_k, top_p, reppenalty), mimetype='text/event-stream')

@app.route("/tts", methods=['POST'])
def tts_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = text_to_speech(text)
    if output_file == "Error generating speech.":
        return jsonify({"error": "TTS generation failed"}), 500
    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")

@app.route("/stt", methods=['POST'])
def stt_api():
    if 'audio' not in request.files:
        return jsonify({"error": "Audio file is required"}), 400
    audio_file = request.files['audio']
    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
    audio_file.save(temp_audio_path)
    output_file = speech_to_text(temp_audio_path)
    os.remove(temp_audio_path)
    if output_file == "Error transcribing audio.":
        return jsonify({"error": "STT failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output.txt")

@app.route("/sentiment", methods=['POST'])
def sentiment_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = analyze_sentiment(text)
    if output_file == "Sentiment model not initialized.":
        return jsonify({"error": "Sentiment analysis failed"}), 500
    return jsonify(output_file)

@app.route("/imagegen", methods=['POST'])
def imagegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_image(prompt)
    if output_file == "Error generating image.":
        return jsonify({"error": "Image generation failed"}), 500
    image_io = BytesIO()
    output_file.save(image_io, 'PNG')
    image_io.seek(0)
    return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")

@app.route("/musicgen", methods=['POST'])
def musicgen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_music(prompt)
    if output_file == "Error generating music.":
        return jsonify({"error": "Music generation failed"}), 500
    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")

@app.route("/translation", methods=['POST'])
def translation_api():
    data = request.get_json()
    text = data.get('text')
    target_lang = data.get('target_lang', 'es')
    source_lang = data.get('source_lang', 'en')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
    if output_file == "Error during translation.":
        return jsonify({"error": "Translation failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_translation.txt")

@app.route("/codegen", methods=['POST'])
def codegen_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = generate_code(prompt)
    if output_file == "Error generating code.":
        return jsonify({"error": "Code generation failed"}), 500
    return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")

@app.route("/text_to_video", methods=['POST'])
def text_to_video_api():
    data = request.get_json()
    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    output_file = text_to_video(prompt)
    if output_file == "Error generating video representation.":
        return jsonify({"error": "Text to video failed"}), 500
    return send_file(output_file, mimetype="application/octet-stream", as_attachment=True, download_name="output_video_representation.pt")

@app.route("/summarization", methods=['POST'])
def summarization_api():
    data = request.get_json()
    text = data.get('text')
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = summarize_text(text)
    if output_file == "Error during summarization.":
        return jsonify({"error": "Summarization failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")

@app.route("/image_to_3d", methods=['POST'])
def image_to_3d_api():
    if 'image' not in request.files:
        return jsonify({"error": "Image file is required"}), 400
    image_file = request.files['image']
    temp_image_path = f"temp_image_{uuid.uuid4()}.png"
    image_file.save(temp_image_path)
    output_file = image_to_3d(temp_image_path)
    os.remove(temp_image_path)
    if output_file == "Error converting image to 3D.":
        return jsonify({"error": "Image to 3D failed"}), 500
    return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")


async def main():
    global background_threads, response_queue
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    response_queue = queue.Queue()

    ensure_all_model_files_exist()
    initialize_model()
    await initialize_sklearn()
    initialize_tts_model()
    initialize_stt_model()
    initialize_sentiment_model()
    initialize_imagegen_model()
    ensure_lipsync_files_exist()
    ensure_wav2lip_files_exist()
    initialize_musicgen_model()
    initialize_translation_model()
    initialize_codegen_model()
    initialize_text_to_video_model()
    initialize_summarization_model()
    initialize_image_to_3d_model()

    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
    background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
    background_threads.append(threading.Thread(target=background_training, daemon=True))
    for thread in background_threads:
        thread.start()

    asyncio.create_task(background_reasoning_queue())

    app.run(host="127.0.0.1", port=7860, debug=False)

if __name__ == '__main__':
    asyncio.run(main())
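The routes above return either JSON or file attachments. The following is an illustrative client sketch (not part of the uploaded files), assuming the server is running locally on 127.0.0.1:7860 as configured in main(); the output filenames are arbitrary.

import requests

BASE_URL = "http://127.0.0.1:7860"  # host/port from app.run() in main(); adjust if deployed elsewhere

# /tts expects JSON {"text": ...} and returns a WAV attachment on success.
resp = requests.post(f"{BASE_URL}/tts", json={"text": "Hello from the demo server."})
if resp.ok:
    with open("tts_output.wav", "wb") as f:
        f.write(resp.content)

# /translation wraps the language codes as "<lang>_XX" before calling perform_translation.
resp = requests.post(
    f"{BASE_URL}/translation",
    json={"text": "Good morning", "source_lang": "en", "target_lang": "es"},
)
if resp.ok:
    with open("output_translation.txt", "wb") as f:
        f.write(resp.content)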
codegen_torch.py
ADDED
@@ -0,0 +1,187 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os
import sentencepiece as spm
import re

CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"

def ensure_codegen_files_exist():
    os.makedirs(CODEGEN_FOLDER, exist_ok=True)
    for url, filename in CODEGEN_FILES_URLS:
        filepath = os.path.join(CODEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(CODEGEN_FOLDER, CODEGEN_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(CODEGEN_SPM_URL, out=filepath_spm)

class CodeGenConfig:
    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024, n_layer=24, n_head=16, n_inner=None, activation_function="gelu_new", resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02, scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256, **kwargs):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = CodeGenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(transformer_outputs)
        return logits

class CodeGenModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*output_shape)

class CodeGenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = CodeGenMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(hidden_states, attention_mask=attention_mask)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feedforward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feedforward_hidden_states
        return hidden_states

class CodeGenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # n_inner defaults to None in CodeGenConfig; fall back to the conventional 4 * n_embd.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, inner_dim)
        self.c_proj = nn.Linear(inner_dim, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

class CodeGenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.embed_dim = config.n_embd
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale_attn_weights = config.scale_attn_weights
        self.use_cache = config.use_cache
        self.register_buffer("bias", torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view((1, 1, config.n_ctx, config.n_ctx)))

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / math.sqrt(value.size(-1))

        mask = self.bias[:, :, :attn_weights.size(-2), :attn_weights.size(-1)]
        attn_weights = torch.where(mask.bool(), attn_weights, torch.tensor(-1e4, device=attn_weights.device))

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        # Undo _split_heads: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim).
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, past_key_value=None, use_cache=False):
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.n_head, self.embed_dim // self.n_head)
        key = self._split_heads(key, self.n_head, self.embed_dim // self.n_head)
        value = self._split_heads(value, self.n_head, self.embed_dim // self.n_head)
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present_key_value = (key, value) if use_cache else None
        attn_output = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.n_head, self.embed_dim // self.n_head)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present_key_value)
        return outputs[0]
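As a quick smoke test of the classes above, the sketch below (not part of the uploaded files) instantiates a deliberately small CodeGenConfig and runs a dummy forward pass. The sizes are made up for illustration, do not match the codegen-350M-multi checkpoint, and no pretrained weights are loaded.

import torch

# Hypothetical toy configuration; sizes chosen only to keep the smoke test light.
cfg = CodeGenConfig(vocab_size=51200, n_positions=128, n_ctx=128, n_embd=256, n_layer=2, n_head=8)
model = CodeGenForCausalLM(cfg)
model.eval()

dummy_ids = torch.randint(0, cfg.vocab_size, (1, 16))  # batch of one, 16 random token ids
with torch.no_grad():
    logits = model(dummy_ids)
print(logits.shape)  # torch.Size([1, 16, 51200])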
gpt2_pytorch.py
ADDED
@@ -0,0 +1,210 @@
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import wget
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
GPT2_FOLDER = "./GPT2"
|
| 10 |
+
MODEL_FILE = "gpt2-pytorch_model.bin"
|
| 11 |
+
ENCODER_FILE = "encoder.json"
|
| 12 |
+
VOCAB_FILE = "vocab.bpe"
|
| 13 |
+
MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"
|
| 14 |
+
ENCODER_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/encoder.json"
|
| 15 |
+
VOCAB_URL = "https://raw.githubusercontent.com/graykode/gpt-2-Pytorch/refs/heads/master/GPT2/GPT2/vocab.bpe"
|
| 16 |
+
MAX_LENGTH = 1024
|
| 17 |
+
END_OF_TEXT_TOKEN = "<|endoftext|>"
|
| 18 |
+
|
| 19 |
+
def ensure_gpt2_files_exist():
|
| 20 |
+
if not os.path.exists(os.path.join(GPT2_FOLDER, MODEL_FILE)):
|
| 21 |
+
wget.download(MODEL_URL, out=os.path.join(GPT2_FOLDER, MODEL_FILE))
|
| 22 |
+
if not os.path.exists(os.path.join(GPT2_FOLDER, ENCODER_FILE)):
|
| 23 |
+
wget.download(ENCODER_URL, out=os.path.join(GPT2_FOLDER, ENCODER_FILE))
|
| 24 |
+
if not os.path.exists(os.path.join(GPT2_FOLDER, VOCAB_FILE)):
|
| 25 |
+
wget.download(VOCAB_URL, out=os.path.join(GPT2_FOLDER, VOCAB_FILE))
|
| 26 |
+
|
| 27 |
+
class GPT2Config:
|
| 28 |
+
def __init__(self, vocab_size_or_config_json_file=50257, n_positions=MAX_LENGTH, n_ctx=MAX_LENGTH, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-5, initializer_range=0.02):
|
| 29 |
+
self.vocab_size = vocab_size_or_config_json_file
|
| 30 |
+
self.n_ctx = n_ctx
|
| 31 |
+
self.n_positions = n_positions
|
| 32 |
+
self.n_embd = n_embd
|
| 33 |
+
self.n_layer = n_layer
|
| 34 |
+
self.n_head = n_head
|
| 35 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
| 36 |
+
self.initializer_range = initializer_range
|
| 37 |
+
|
| 38 |
+
class GPT2LMHeadModel(nn.Module):
|
| 39 |
+
def __init__(self, config):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.transformer = GPT2Model(config)
|
| 42 |
+
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
| 43 |
+
|
| 44 |
+
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
|
| 45 |
+
lm_logits, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
|
| 46 |
+
return lm_logits, presents
|
| 47 |
+
|
| 48 |
+
class GPT2Model(nn.Module):
|
| 49 |
+
def __init__(self, config):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.n_layer = config.n_layer
|
| 52 |
+
self.n_embd = config.n_embd
|
| 53 |
+
self.n_vocab = config.vocab_size
|
| 54 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 55 |
+
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
| 56 |
+
block = Block(config.n_ctx, config, scale=True)
|
| 57 |
+
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.decoder.weight = model_embeddings_weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        return lm_logits

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd:ns, :ns]
        # Causal mask: future positions get a large negative score so they vanish after the softmax.
        w = w * b - 1e10 * (1 - b)
        w = nn.Softmax(dim=-1)(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

class MLP(nn.Module):
    def __init__(self, n_state, config):
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2

class Conv1D(nn.Module):
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
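A minimal standalone sketch (not part of the uploaded files) of the causal-masking step used in Attention._attn above: the lower-triangular bias buffer zeroes out scores for future positions and replaces them with a large negative value, so they contribute practically nothing after the softmax.

import torch
import torch.nn as nn

n_ctx = 4
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)  # same causal buffer shape as Attention
scores = torch.randn(1, 1, n_ctx, n_ctx)                              # raw attention scores
scores = scores * bias - 1e10 * (1 - bias)                            # mask future positions
probs = nn.Softmax(dim=-1)(scores)
print(probs[0, 0])                                                    # entries above the diagonal are effectively zero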
image_to_3d_openlrm.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import wget
import json
import os

IMAGE_TO_3D_FOLDER = "./ImageTo3DModel"
IMAGE_TO_3D_MODEL_WEIGHTS = "pytorch_model.bin"
IMAGE_TO_3D_CONFIG = "config.json"
IMAGE_TO_3D_MODEL_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/pytorch_model.bin"
IMAGE_TO_3D_CONFIG_URL = "https://huggingface.co/zxhezexin/openlrm-obj-base-1.1/resolve/main/config.json"
IMAGE_TO_3D_FILES_URLS = [
    (IMAGE_TO_3D_MODEL_URL, IMAGE_TO_3D_MODEL_WEIGHTS),
    (IMAGE_TO_3D_CONFIG_URL, IMAGE_TO_3D_CONFIG),
]

def ensure_image_to_3d_files_exist():
    os.makedirs(IMAGE_TO_3D_FOLDER, exist_ok=True)
    for url, filename in IMAGE_TO_3D_FILES_URLS:
        filepath = os.path.join(IMAGE_TO_3D_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class OpenLRM(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
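A hedged usage sketch for this module: the download helper assumes network access to the Hugging Face URLs above (and the wget package), and the OpenLRM class is only the placeholder linear head defined above, not the real OpenLRM architecture.

import torch
from image_to_3d_openlrm import ensure_image_to_3d_files_exist, OpenLRM

ensure_image_to_3d_files_exist()   # fetches pytorch_model.bin and config.json if they are missing
model = OpenLRM(num_classes=3)     # placeholder head over 100-dimensional features
features = torch.randn(2, 100)     # dummy input batch
print(model(features).shape)       # torch.Size([2, 3])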
imagegen_vae_unet.py
ADDED
@@ -0,0 +1,164 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os

IMAGEGEN_FOLDER = "./ImageGenModel"
IMAGEGEN_MODEL_WEIGHTS = "diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG = "config.json"
IMAGEGEN_MODEL_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin"
IMAGEGEN_CONFIG_URL = "https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json"
IMAGEGEN_FILES_URLS = [
    (IMAGEGEN_MODEL_URL, IMAGEGEN_MODEL_WEIGHTS),
    (IMAGEGEN_CONFIG_URL, IMAGEGEN_CONFIG),
]

def ensure_imagegen_files_exist():
    os.makedirs(IMAGEGEN_FOLDER, exist_ok=True)
    for url, filename in IMAGEGEN_FILES_URLS:
        filepath = os.path.join(IMAGEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class UNet2DConditionModelConfig:
    def __init__(self, **kwargs):
        self.sample_size = 64
        self.layers_per_block = 2
        self.block_out_channels = [320, 640, 1280, 1280]
        self.downsample = [2, 2, 2, 2]
        self.upsample = [2, 2, 2, 2]
        self.cross_attention_dim = 768
        self.act_fn = "silu"
        self.norm_num_groups = 32
        self.num_attention_heads = 8
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class UNet2DConditionModel(nn.Module):
    def __init__(self, config: UNet2DConditionModelConfig):
        super().__init__()
        self.conv_in = nn.Conv2d(4, config.block_out_channels[0], kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList([])
        for i in range(len(config.block_out_channels)):
            is_final_block = i == len(config.block_out_channels) - 1
            downsample_factor = 1 if is_final_block else config.downsample[i]
            out_channels = config.block_out_channels[i]
            layers_per_block = config.layers_per_block
            self.down_blocks.append(DownBlock(out_channels, layers_per_block, downsample_factor))
        self.mid_block = MidBlock(config.block_out_channels[-1])
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(config.block_out_channels))
        reversed_upsample_factors = list(reversed(config.upsample))
        for i in range(len(config.block_out_channels)):
            is_final_block = i == len(config.block_out_channels) - 1
            upsample_factor = 1 if is_final_block else reversed_upsample_factors[i]
            out_channels = reversed_block_out_channels[i]
            layers_per_block = config.layers_per_block
            self.up_blocks.append(UpBlock(out_channels, layers_per_block, upsample_factor))
        self.norm_out = nn.GroupNorm(num_groups=config.norm_num_groups, num_channels=config.block_out_channels[0])
        self.conv_norm_out = nn.Conv2d(config.block_out_channels[0], config.block_out_channels[0], kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(config.block_out_channels[0], 4, kernel_size=3, padding=1)

    def forward(self, sample: torch.FloatTensor, timestep: torch.IntTensor, encoder_hidden_states: torch.FloatTensor):
        sample = self.conv_in(sample)
        for down_block in self.down_blocks:
            sample = down_block(sample)
        sample = self.mid_block(sample)
        for up_block in self.up_blocks:
            sample = up_block(sample)
        sample = self.norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_out(sample)
        return {"sample": sample}

class DownBlock(nn.Module):
    def __init__(self, out_channels, layers_per_block, downsample_factor):
        super().__init__()
        self.layers = nn.ModuleList([ResnetBlock(out_channels) for _ in range(layers_per_block)])
        if downsample_factor > 1:
            self.downsample = Downsample2D(out_channels, downsample_factor)
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.downsample(x)
        return x

class UpBlock(nn.Module):
    def __init__(self, out_channels, layers_per_block, upsample_factor):
        super().__init__()
        self.layers = nn.ModuleList([ResnetBlock(out_channels) for _ in range(layers_per_block)])
        if upsample_factor > 1:
            self.upsample = Upsample2D(out_channels, upsample_factor)
        else:
            self.upsample = nn.Identity()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.upsample(x)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.residual_conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x + self.residual_conv(residual)

class MidBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.norm1(x)
        x = F.silu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = F.silu(x)
        x = self.conv2(x)
        return x

class Downsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=factor, padding=1)

    def forward(self, x):
        return self.conv(x)

class Upsample2D(nn.Module):
    def __init__(self, channels, factor):
        super().__init__()
        self.factor = factor
        self.conv = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor)

    def forward(self, x):
        return self.conv(x)
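An illustrative shape check (assumed usage, not from the uploaded files) for the building blocks above: ResnetBlock preserves the channel count, Downsample2D halves the spatial size when factor=2, and Upsample2D restores it. It deliberately does not run the full UNet2DConditionModel forward pass.

import torch
from imagegen_vae_unet import ResnetBlock, Downsample2D, Upsample2D

x = torch.randn(1, 320, 64, 64)      # a batch of latent feature maps
block = ResnetBlock(320)
down = Downsample2D(320, factor=2)
up = Upsample2D(320, factor=2)
y = down(block(x))
print(y.shape)                        # torch.Size([1, 320, 32, 32])
print(up(y).shape)                    # torch.Size([1, 320, 64, 64])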
lipsync_wav2lip.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import wget
import os

LIPSYNC_FOLDER = "./LipSyncModel"
LIPSYNC_MODEL_WEIGHTS = "lipsync_expert.pth"
LIPSYNC_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Flipsync%5Fexpert%2Epth"
LIPSYNC_FILES_URLS = [
    (LIPSYNC_MODEL_WEIGHTS_URL, LIPSYNC_MODEL_WEIGHTS),
]

WAV2LIP_FOLDER = "./Wav2LipModel"
WAV2LIP_MODEL_WEIGHTS = "wav2lip_gan.pth"
WAV2LIP_MODEL_WEIGHTS_URL = "https://iiitaphyd-my.sharepoint.com/personal/radrabha_m_research_iiit_ac_in/_layouts/15/download.aspx?SourceUrl=%2Fpersonal%2Fradrabha%5Fm%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FWav2Lip%5FModels%2Fwav2lip%5Fgan%2Epth"
WAV2LIP_FILES_URLS = [
    (WAV2LIP_MODEL_WEIGHTS_URL, WAV2LIP_MODEL_WEIGHTS),
]

def ensure_lipsync_files_exist():
    os.makedirs(LIPSYNC_FOLDER, exist_ok=True)
    for url, filename in LIPSYNC_FILES_URLS:
        filepath = os.path.join(LIPSYNC_FOLDER, filename)
        if not os.path.exists(filepath):
            try:
                wget.download(url, out=filepath)
            except Exception as e:
                print(f"Warning: Download for {filename} failed, likely due to link restrictions. You may need to download it manually.")

def ensure_wav2lip_files_exist():
    os.makedirs(WAV2LIP_FOLDER, exist_ok=True)
    for url, filename in WAV2LIP_FILES_URLS:
        filepath = os.path.join(WAV2LIP_FOLDER, filename)
        if not os.path.exists(filepath):
            try:
                wget.download(url, out=filepath)
            except Exception as e:
                print(f"Warning: Download for {filename} failed, likely due to link restrictions. You may need to download it manually.")


class LipSyncModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

class Wav2LipModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
musicgen_torch.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

MUSICGEN_FOLDER = "./MusicGenModel"
MUSICGEN_MODEL_NAME = "melody"
MUSICGEN_MODEL_WEIGHTS = "pytorch_model.bin"
MUSICGEN_CONFIG = "config.json"
MUSICGEN_SAMPLE_RATE = 32000
MUSICGEN_DURATION = 8
MUSICGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/pytorch_model.bin"
MUSICGEN_CONFIG_URL = "https://huggingface.co/facebook/musicgen-small/resolve/main/config.json"
MUSICGEN_FILES_URLS = [
    (MUSICGEN_MODEL_WEIGHTS_URL, MUSICGEN_MODEL_WEIGHTS),
    (MUSICGEN_CONFIG_URL, MUSICGEN_CONFIG),
]

def ensure_musicgen_files_exist():
    os.makedirs(MUSICGEN_FOLDER, exist_ok=True)
    for url, filename in MUSICGEN_FILES_URLS:
        filepath = os.path.join(MUSICGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class MusicGenModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
sentiment_roberta.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # provides the GELU activation used by RobertaIntermediate
import math  # used for the attention score scaling
import wget
import json
import os

SENTIMENT_FOLDER = "./SentimentModel"
SENTIMENT_MODEL_WEIGHTS = "pytorch_model.bin"
SENTIMENT_VOCAB = "sentiment_vocab.json"
SENTIMENT_CONFIG = "config.json"
SENTIMENT_MODEL_WEIGHTS_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/pytorch_model.bin"
SENTIMENT_VOCAB_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/vocab.json"
SENTIMENT_CONFIG_URL = "https://huggingface.co/cardiffnlp/distilroberta-base-sentiment/resolve/main/config.json"
SENTIMENT_FILES_URLS = [
    (SENTIMENT_MODEL_WEIGHTS_URL, SENTIMENT_MODEL_WEIGHTS),
    (SENTIMENT_VOCAB_URL, SENTIMENT_VOCAB),
    (SENTIMENT_CONFIG_URL, SENTIMENT_CONFIG),
]

def ensure_sentiment_files_exist():
    os.makedirs(SENTIMENT_FOLDER, exist_ok=True)
    for url, filename in SENTIMENT_FILES_URLS:
        filepath = os.path.join(SENTIMENT_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class RobertaForSequenceClassification(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.dense = nn.Linear(768, 768)
        self.dropout = nn.Dropout(0.1)
        self.out_proj = nn.Linear(768, num_labels)

    def forward(self, sequence_output):
        x = sequence_output[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class RobertaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

    def forward(self, input_ids, attention_mask=None):
        embedding_output = self.embeddings(input_ids)
        encoder_outputs = self.encoder(embedding_output, attention_mask=attention_mask)
        return (encoder_outputs[0], )

class RobertaEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_ids = torch.arange(config.max_position_embeddings).expand((1, -1))

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)

        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) + self.token_type_embeddings(token_type_ids)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class RobertaEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask=attention_mask)
            all_encoder_layers.append(hidden_states)
        return (hidden_states, all_encoder_layers)

class RobertaLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = RobertaAttention(config)
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        attention_output = self.attention(hidden_states, attention_mask=attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class RobertaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = RobertaSelfAttention(config)
        self.output = RobertaSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        self_output = self.self_attn(hidden_states, attention_mask=attention_mask)
        attention_output = self.output(self_output, hidden_states)
        return attention_output

class RobertaSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

class RobertaSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # project the attention output back to the hidden size
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class RobertaIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = F.gelu

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class RobertaOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
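A hedged sketch of how the classification head above is meant to be driven: it pools the first token of a (batch, seq_len, 768) hidden-state tensor and projects it to label logits. Random tensors stand in for real RoBERTa encoder activations.

import torch
from sentiment_roberta import RobertaForSequenceClassification

head = RobertaForSequenceClassification(num_labels=3)   # e.g. negative / neutral / positive
sequence_output = torch.randn(2, 16, 768)               # dummy encoder output: (batch, seq_len, hidden)
logits = head(sequence_output)
print(logits.shape)                                      # torch.Size([2, 3])
print(logits.softmax(dim=-1))                            # per-class probabilities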
stt_wav2vec2.py
ADDED
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

STT_FOLDER = "./STTModel"
STT_MODEL_NAME = "wav2vec2"
STT_MODEL_WEIGHTS = "pytorch_model.bin"
STT_CONFIG = "config.json"
STT_VOCAB = "vocab.json"
STT_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/pytorch_model.bin"
STT_CONFIG_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json"
STT_VOCAB_URL = "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
STT_FILES_URLS = [
    (STT_MODEL_WEIGHTS_URL, STT_MODEL_WEIGHTS),
    (STT_CONFIG_URL, STT_CONFIG),
    (STT_VOCAB_URL, STT_VOCAB),
]

def ensure_stt_files_exist():
    os.makedirs(STT_FOLDER, exist_ok=True)
    for url, filename in STT_FILES_URLS:
        filepath = os.path.join(STT_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class Wav2Vec2ForCTC(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32 * 39 * 40, num_classes)  # Adjusted input size

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        return logits
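A shape-check sketch for the convolutional stack above (illustrative only): the flattened activation reaching self.fc must equal its hard-coded in_features of 32 * 39 * 40 = 49920, and one waveform length that satisfies the conv/pool arithmetic is 24960 samples.

import torch
from stt_wav2vec2 import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC(num_classes=29)   # e.g. characters plus a CTC blank
waveform = torch.randn(1, 1, 24960)      # (batch, channels, samples); length chosen to match the fc size
print(model(waveform).shape)             # torch.Size([1, 29])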
summarization_bart.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import wget
import json
import os

SUMMARIZATION_FOLDER = "./SummarizationModel"
SUMMARIZATION_MODEL_WEIGHTS = "pytorch_model.bin"
SUMMARIZATION_CONFIG = "config.json"
SUMMARIZATION_VOCAB = "vocab.json"
SUMMARIZATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/pytorch_model.bin"
SUMMARIZATION_CONFIG_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/config.json"
SUMMARIZATION_VOCAB_URL = "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json"
SUMMARIZATION_FILES_URLS = [
    (SUMMARIZATION_MODEL_WEIGHTS_URL, SUMMARIZATION_MODEL_WEIGHTS),
    (SUMMARIZATION_CONFIG_URL, SUMMARIZATION_CONFIG),
    (SUMMARIZATION_VOCAB_URL, SUMMARIZATION_VOCAB),
]

def ensure_summarization_files_exist():
    os.makedirs(SUMMARIZATION_FOLDER, exist_ok=True)
    for url, filename in SUMMARIZATION_FILES_URLS:
        filepath = os.path.join(SUMMARIZATION_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class BartForConditionalGeneration(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
text_to_video_clip4clip.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import wget
import json
import os

TEXT_TO_VIDEO_FOLDER = "./TextToVideoModel"
TEXT_TO_VIDEO_MODEL_WEIGHTS = "pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG = "config.json"
TEXT_TO_VIDEO_VOCAB = "vocab.json"
TEXT_TO_VIDEO_MODEL_WEIGHTS_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/pytorch_model.bin"
TEXT_TO_VIDEO_CONFIG_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/config.json"
TEXT_TO_VIDEO_VOCAB_URL = "https://huggingface.co/Searchium-ai/clip4clip-webvid150k/resolve/main/vocab.json"
TEXT_TO_VIDEO_FILES_URLS = [
    (TEXT_TO_VIDEO_MODEL_WEIGHTS_URL, TEXT_TO_VIDEO_MODEL_WEIGHTS),
    (TEXT_TO_VIDEO_CONFIG_URL, TEXT_TO_VIDEO_CONFIG),
    (TEXT_TO_VIDEO_VOCAB_URL, TEXT_TO_VIDEO_VOCAB),
]

def ensure_text_to_video_files_exist():
    os.makedirs(TEXT_TO_VIDEO_FOLDER, exist_ok=True)
    for url, filename in TEXT_TO_VIDEO_FILES_URLS:
        filepath = os.path.join(TEXT_TO_VIDEO_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class Clip4ClipModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(100, num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits
translation_mbart.py
ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # used for F.relu and F.pad below
import math  # used by the sinusoidal positional embedding
import wget
import json
import os
import sentencepiece as spm
import re

TRANSLATION_FOLDER = "./TranslationModel"
TRANSLATION_MODEL_WEIGHTS_FILE = "pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_FILE = "config.json"
TRANSLATION_MODEL_VOCAB_FILE = "sentencepiece.bpe.model"
TRANSLATION_MODEL_WEIGHTS_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/pytorch_model.bin"
TRANSLATION_MODEL_CONFIG_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/config.json"
TRANSLATION_MODEL_VOCAB_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_MODEL_FILES_URLS = [
    (TRANSLATION_MODEL_WEIGHTS_URL, TRANSLATION_MODEL_WEIGHTS_FILE),
    (TRANSLATION_MODEL_CONFIG_URL, TRANSLATION_MODEL_CONFIG_FILE),
    (TRANSLATION_MODEL_VOCAB_URL, TRANSLATION_MODEL_VOCAB_FILE),
]
TRANSLATION_SPM_URL = "https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt/resolve/main/sentencepiece.bpe.model"
TRANSLATION_SPM = "sentencepiece.bpe.model"

def ensure_translation_files_exist():
    os.makedirs(TRANSLATION_FOLDER, exist_ok=True)
    for url, filename in TRANSLATION_MODEL_FILES_URLS:
        filepath = os.path.join(TRANSLATION_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(TRANSLATION_FOLDER, TRANSLATION_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(TRANSLATION_SPM_URL, out=filepath_spm)

class MBartConfig:
    def __init__(self, vocab_size, hidden_size=1024, num_hidden_layers=12, num_attention_heads=16, intermediate_size=4096, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, layer_norm_eps=1e-05, initializer_range=0.02, pad_token_id=1, bos_token_id=0, eos_token_id=2, n_positions=1024, n_ctx=1024, decoder_layers=12, decoder_attention_heads=16, decoder_ffn_dim=4096, encoder_layers=12, encoder_attention_heads=16, encoder_ffn_dim=4096, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)

class MBartForConditionalGeneration(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = MBartModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.final_logits_bias = nn.Parameter(torch.zeros((1, config.vocab_size)))

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
        return lm_logits

class MBartModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = MBartEncoder(config)
        self.decoder = MBartDecoder(config)

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None):
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
        decoder_outputs = self.decoder(decoder_input_ids, encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask)
        return decoder_outputs

class MBartEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.embed_positions = MBartSinusoidalPositionalEmbedding(config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.hidden_size)

    def forward(self, input_ids, attention_mask=None):
        inputs_embeds = self.embed_tokens(input_ids)
        position_embeddings = self.embed_positions(input_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layernorm_embedding(embeddings)
        encoder_states = embeddings
        all_encoder_layers = []
        for layer_module in self.layers:
            encoder_states = layer_module(encoder_states, encoder_padding_mask=attention_mask)
            all_encoder_layers.append(encoder_states)
        return (encoder_states, all_encoder_layers)

class MBartDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.embed_positions = MBartSinusoidalPositionalEmbedding(config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.hidden_size)

    def forward(self, decoder_input_ids, encoder_outputs, decoder_attention_mask=None):
        inputs_embeds = self.embed_tokens(decoder_input_ids)
        position_embeddings = self.embed_positions(decoder_input_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layernorm_embedding(embeddings)
        decoder_states = embeddings
        all_decoder_layers = []
        all_cross_attention_layers = []
        for layer_module in self.layers:
            decoder_states, cross_attn_weights = layer_module(decoder_states, encoder_outputs[0], decoder_padding_mask=decoder_attention_mask, encoder_padding_mask=encoder_outputs[0])
            all_decoder_layers.append(decoder_states)
            all_cross_attention_layers.append(cross_attn_weights)
        return (decoder_states, all_decoder_layers, all_cross_attention_layers)

class MBartSinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, padding_idx):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(self.padding_idx + 1, seq_len + self.padding_idx + 1, dtype=torch.long, device=input_ids.device)
        return self.get_embedding(positions)

    def get_embedding(self, positions):
        half_dim = self.embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=positions.device) * -emb)
        emb = torch.outer(positions.float(), emb)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embedding_dim % 2 == 1:
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb

class MBartEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.encoder_attention_heads)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, config.hidden_size)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states, encoder_padding_mask=None):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attention_mask=encoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
        hidden_states = residual + hidden_states
        return hidden_states

class MBartDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.encoder_attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.decoder_attention_heads)
        self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, config.hidden_size)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states, encoder_hidden_states, decoder_padding_mask=None, encoder_padding_mask=None):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attention_mask=decoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states, cross_attn_weights = self.encoder_attn(hidden_states, encoder_hidden_states, encoder_hidden_states, attention_mask=encoder_padding_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc2(F.relu(self.fc1(hidden_states)))
        hidden_states = residual + hidden_states
        return hidden_states, cross_attn_weights

class MBartAttention(nn.Module):
    def __init__(self, config, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, query, key, value, attention_mask=None):
        bsz, tgt_len, _ = query.size()
        bsz, src_len, _ = key.size()
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)
        query = self._shape(query, tgt_len, bsz)
        key = self._shape(key, src_len, bsz)
        value = self._shape(value, src_len, bsz)
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) * self.scaling

        if attention_mask is not None:
            attention_mask = attention_mask.float().masked_fill(attention_mask == 0, float('-inf')).masked_fill(attention_mask == 1, float(0.0))
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights

class MBartTokenizer:
    def __init__(self, sentencepiece_processor):
        self.sp = sentencepiece_processor
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token_id = 1
        self.bos_token_id = 0
        self.eos_token_id = 2
        self.model_max_length = 1024

    def __call__(self, text, return_tensors="pt", padding=True, truncation=True, max_length=None, src_lang="en_XX", tgt_lang="es_XX", **kwargs):
        max_length = max_length if max_length is not None else self.model_max_length
        self.sp.SetEncodeExtraOptions("bos:eos")  # prepend <s> and append </s> when encoding
        input_ids = self.sp.EncodeAsIds(f"{src_lang} {text}")
        if truncation and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
        if padding:
            input_ids += [self.pad_token_id] * (max_length - len(input_ids))
        if return_tensors == "pt":
            return {"input_ids": torch.tensor([input_ids]), "attention_mask": torch.ones(len(input_ids)).unsqueeze(0)}
        return input_ids

    def batch_decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, target_lang="es_XX"):
        decoded_texts = []
        for ids in token_ids:
            text = self.sp.DecodeIds(list(ids))
            if skip_special_tokens:
                text = re.sub(r'(<s>|</s>|<pad>)', '', text).strip()
            if clean_up_tokenization_spaces:
                text = text.replace('  ', ' ').strip()
            decoded_texts.append(text.replace(f"{target_lang} ", ""))
        return decoded_texts
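A hedged sketch that exercises only the MBartAttention module above on random tensors, using the MBartConfig defaults (hidden_size=1024, 16 heads) and no padding mask; it is not a full translation pass.

import torch
from translation_mbart import MBartConfig, MBartAttention

config = MBartConfig(vocab_size=250054)
attn = MBartAttention(config, embed_dim=config.hidden_size, num_heads=config.encoder_attention_heads)
hidden = torch.randn(2, 7, config.hidden_size)   # (batch, seq_len, hidden)
out, weights = attn(hidden, hidden, hidden)      # self-attention, no mask
print(out.shape)                                  # torch.Size([2, 7, 1024])
print(weights.shape)                              # torch.Size([2, 16, 7, 7])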
tts_vits.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wget
import json
import os

TTS_FOLDER = "./TTSModel"
TTS_MODEL_NAME = "vits"
TTS_MODEL_CONFIG = "config.json"
TTS_MODEL_WEIGHTS = "pytorch_model.bin"
TTS_VOCAB = "vocab.json"
TTS_CONFIG_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/config.json"
TTS_MODEL_WEIGHTS_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/pytorch_model.bin"
TTS_VOCAB_URL = "https://huggingface.co/kakao-enterprise/vits-vctk/resolve/main/vocab.json"
TTS_FILES_URLS = [
    (TTS_CONFIG_URL, TTS_MODEL_CONFIG),
    (TTS_MODEL_WEIGHTS_URL, TTS_MODEL_WEIGHTS),
    (TTS_VOCAB_URL, TTS_VOCAB),
]

def ensure_tts_files_exist():
    os.makedirs(TTS_FOLDER, exist_ok=True)
    for url, filename in TTS_FILES_URLS:
        filepath = os.path.join(TTS_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)

class VITS(nn.Module):
    def __init__(self, spec_channels, segment_size, num_speakers, num_languages, num_symbols):
        super().__init__()
        self.spec_channels = spec_channels
        self.segment_size = segment_size
        self.num_speakers = num_speakers
        self.num_languages = num_languages
        self.num_symbols = num_symbols
        self.embedding = nn.Embedding(num_symbols, 192)
        self.decoder = Generator(spec_channels)

    def forward(self, text):
        x = self.embedding(text)
        audio = self.decoder(x)
        return audio

class Generator(nn.Module):
    def __init__(self, spec_channels):
        super().__init__()
        self.spec_channels = spec_channels
        self.initial_conv = nn.ConvTranspose2d(192, spec_channels, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        self.final_conv = nn.Conv2d(spec_channels, 1, kernel_size=(7, 7), padding=(3, 3))

    def forward(self, encoder_outputs):
        x = encoder_outputs.transpose(1, 2).unsqueeze(-1)  # (batch, embed, seq, 1): embedding axis becomes the conv channel axis
        x = self.initial_conv(x)
        x = self.final_conv(x)
        return x.squeeze(1)
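A hedged usage sketch for the placeholder VITS above with toy sizes (the numbers here are arbitrary assumptions, not the real vits-vctk configuration); it simply pushes a batch of symbol ids through the embedding and the small Generator.

import torch
from tts_vits import VITS

model = VITS(spec_channels=80, segment_size=8192, num_speakers=109, num_languages=1, num_symbols=256)
tokens = torch.randint(0, 256, (1, 50))   # a dummy sequence of 50 symbol ids
audio = model(tokens)
print(audio.shape)                         # torch.Size([1, 100, 2]) with these toy sizes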