// Uploaded by thejagstudio: "Upload 10 files" (commit 486838c, verified)
// Global state: becomes true once the backend reports (or finishes) a model load.
let isModelLoaded = false;

// DOM Elements, resolved once at script start and shared by every handler.
// `testStreamBtn` and `chatList` are looked up here but not referenced by the
// handlers in this script.
const els = (() => {
  const byId = (id) => document.getElementById(id);
  return {
    chat: byId("chat"),
    promptForm: byId("prompt-form"),
    promptInput: byId("prompt"),
    loadBtn: byId("btn-load"),
    testStreamBtn: byId("btn-test-stream"),
    status: byId("load-status"),
    sidebar: byId("sidebar"),
    sidebarToggle: byId("btn-toggle-sidebar"),
    chatList: byId("chat-list"),
    newChatBtn: byId("new-chat"),
    sendBtn: byId("btn-send"),
    // Generation sliders and the labels that echo their values.
    steps: byId("steps"),
    block_size: byId("block_size"),
    max_new_tokens: byId("max_new_tokens"),
    parallel_blocks: byId("parallel_blocks"),
    stepsValue: byId("steps-value"),
    block_sizeValue: byId("block_size-value"),
    max_new_tokensValue: byId("max_new_tokens-value"),
    parallel_blocksValue: byId("parallel_blocks-value"),
  };
})();
// Update slider values: each slider mirrors its current value into the label
// stored under the same key plus a "Value" suffix.
for (const [sliderKey, labelKey] of [
  ["steps", "stepsValue"],
  ["block_size", "block_sizeValue"],
  ["max_new_tokens", "max_new_tokensValue"],
  ["parallel_blocks", "parallel_blocksValue"],
]) {
  els[sliderKey].addEventListener("input", () => {
    els[labelKey].textContent = els[sliderKey].value;
  });
}
// --- Logic ---
/**
 * Ask the backend whether the model is already loaded (`check_only: true`),
 * without requesting a load.
 *
 * When the reply is OK and reports `loaded`, sets `isModelLoaded`, shows
 * "Ready" and hides the Load button. Otherwise the UI is left unchanged.
 * Network/parse failures are logged and never rethrown.
 */
async function checkLoadStatus() {
  try {
    const response = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: true }),
    });
    if (!response.ok) {
      return;
    }
    const payload = await response.json();
    if (!payload.loaded) {
      return;
    }
    isModelLoaded = true;
    els.status.textContent = "Ready";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (err) {
    console.log("Model check failed:", err);
  }
}
// "Load Model" button: ask the backend to load the model (`check_only: false`)
// and reflect the outcome in the status label. Failures are also alerted.
els.loadBtn.addEventListener("click", async () => {
  els.loadBtn.disabled = true;
  els.status.textContent = "Loading Model (this may take time)...";
  els.status.className = "text-sm text-yellow-600 font-medium";
  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: false }),
    });
    // Error responses are not guaranteed to be JSON (e.g. an HTML error page
    // from a proxy). Don't let that parse failure hide the real HTTP status.
    let data = {};
    try {
      data = await res.json();
    } catch (parseErr) {
      if (res.ok) throw parseErr;
    }
    if (!res.ok) {
      const serverMessage = data && data.message;
      throw new Error(serverMessage || `Load failed (HTTP ${res.status})`);
    }
    isModelLoaded = true;
    els.status.textContent = "Model Loaded";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (e) {
    els.status.textContent = "Error Loading";
    els.status.className = "text-sm text-red-500";
    alert(`Error: ${e.message}`);
  } finally {
    els.loadBtn.disabled = false;
  }
});
// Prompt submit: show the user's message, open an assistant bubble, and stream
// the server-sent events from /api/generate-stream into it. The form stays
// locked while generating.
els.promptForm.addEventListener("submit", async (e) => {
  e.preventDefault();
  const text = els.promptInput.value.trim();
  if (!text) return;

  // UI Updates
  addMessage("user", text);
  els.promptInput.value = "";

  // Create Assistant Bubble
  const assistantBubble = addMessage("assistant", "");
  const contentPre = assistantBubble.querySelector(".content");
  const textContent = contentPre.querySelector(".text-content");
  const visualizationDiv = document.createElement("div");
  visualizationDiv.className = "visualization mb-2 font-mono text-xs";

  // Loading spinner (SVG). The first "update" event's render replaces it.
  const spinner = document.createElement("div");
  spinner.className = "flex items-center gap-2 text-slate-400";
  spinner.innerHTML = `
<svg class="animate-spin h-4 w-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v4a4 4 0 00-4 4H4z"></path>
</svg>
<span class="text-xs">Generating...</span>
`;
  visualizationDiv.appendChild(spinner);
  contentPre.insertBefore(visualizationDiv, textContent);

  // Disable send button
  els.sendBtn.disabled = true;
  els.sendBtn.textContent = "Generating...";
  els.promptInput.disabled = true;

  // Handle one SSE line. Only `data:` fields carry payloads. The SSE format
  // allows CRLF line endings and makes the space after "data:" optional.
  const handleSseLine = (rawLine) => {
    const line = rawLine.endsWith("\r") ? rawLine.slice(0, -1) : rawLine;
    if (!line.startsWith("data:")) return;
    const jsonStr = line.slice(line.startsWith("data: ") ? 6 : 5);
    if (!jsonStr.trim()) return;
    try {
      const data = JSON.parse(jsonStr);
      handleStreamEvent(data, visualizationDiv, textContent);
    } catch (err) {
      console.error("Failed to parse SSE data:", err);
    }
  };

  // Generate Request with Streaming
  try {
    const res = await fetch("/api/generate-stream", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        instruction: text,
        steps: Number.parseInt(els.steps.value, 10),
        block_size: Number.parseInt(els.block_size.value, 10),
        max_new_tokens: Number.parseInt(els.max_new_tokens.value, 10),
        parallel_blocks: Number.parseInt(els.parallel_blocks.value, 10),
      }),
    });
    if (!res.ok) {
      throw new Error(`Server Error ${res.status}`);
    }
    const reader = res.body.getReader();
    const decoder = new TextDecoder();
    let buffer = "";
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop(); // Keep incomplete line in buffer
      for (const line of lines) {
        handleSseLine(line);
      }
    }
    // Flush bytes the streaming decoder may still hold, then handle a final
    // line that arrived without a trailing newline.
    buffer += decoder.decode();
    if (buffer) handleSseLine(buffer);
  } catch (error) {
    if (textContent) textContent.textContent = `Error: ${error.message}`;
  } finally {
    // If the request failed or the stream ended before any "update"/"complete"
    // event, the spinner is still attached. Remove it so it doesn't spin
    // forever. If it was already replaced, this is a no-op.
    spinner.remove();
    els.sendBtn.disabled = false;
    els.sendBtn.textContent = "Send";
    els.promptInput.disabled = false;
  }
});
/**
 * Apply one parsed server-sent event to an assistant bubble.
 *
 * - "start":    clear the text area.
 * - "update":   re-render the in-progress visualization from `data.data`.
 * - "complete": drop the visualization and show `data.response`
 *               (or "No response" when it is empty/missing).
 * - "error":    drop the visualization/spinner and show the error message.
 * Any other `type` is ignored.
 *
 * @param {object} data - Parsed event payload; `type` selects the branch.
 * @param {HTMLElement} visualizationDiv - Container for the spinner / live visualization.
 * @param {HTMLElement} textContent - Element that receives the final text or error.
 */
function handleStreamEvent(data, visualizationDiv, textContent) {
  switch (data.type) {
    case "start":
      textContent.textContent = "";
      break;
    case "update":
      renderVisualization(data.data, visualizationDiv);
      scrollToBottom();
      break;
    case "complete":
      visualizationDiv.innerHTML = "";
      textContent.textContent = data.response || "No response";
      scrollToBottom();
      break;
    case "error":
      // Without this the spinner (or a stale partial render) stays next to the error.
      visualizationDiv.innerHTML = "";
      textContent.textContent = `Error: ${data.error || "Unknown error"}`;
      scrollToBottom();
      break;
    default:
      break;
  }
}
/**
 * Render one in-progress generation snapshot into `container`, replacing
 * whatever it held before (including the initial spinner).
 *
 * Output: a context line, then each block's tokens in its own colour (cycling
 * through four colours), then a "Generating: Block 1 | ..." legend when
 * `vizData.num_blocks > 1`.
 *
 * @param {object} vizData - Snapshot from an "update" event. Fields read here are
 *   `context`, `blocks[].tokens[]` (`{type, text}`) and `num_blocks`.
 *   Shape is as consumed here; confirm against the server.
 * @param {HTMLElement} container - Element whose contents are replaced.
 */
function renderVisualization(vizData, container) {
  // Clear previous content
  container.innerHTML = "";

  // Show context
  const contextDiv = document.createElement("div");
  contextDiv.className = "text-slate-600 mb-1";
  contextDiv.textContent = vizData.context;
  container.appendChild(contextDiv);

  // Show blocks
  const blocksDiv = document.createElement("div");
  blocksDiv.classList.add("flex", "flex-wrap", "gap-0");
  const blockColors = ["text-green-600", "text-cyan-600", "text-yellow-600", "text-purple-600"];
  // Tolerate snapshots without blocks/tokens instead of throwing mid-stream.
  (vizData.blocks || []).forEach((block, blockIdx) => {
    const color = blockColors[blockIdx % blockColors.length];
    const blockSpan = document.createElement("span");
    blockSpan.className = color;
    (block.tokens || []).forEach((token) => {
      if (token.type === "masked") {
        // Masked placeholders get their own span plus a trailing space.
        const maskedSpan = document.createElement("span");
        maskedSpan.className = color;
        // Use textContent (not innerText) so the text is inserted literally,
        // the same way the text-node branch below inserts unmasked tokens.
        maskedSpan.textContent = `${token.text} `;
        blockSpan.appendChild(maskedSpan);
      } else {
        blockSpan.appendChild(document.createTextNode(token.text));
      }
    });
    blocksDiv.appendChild(blockSpan);
  });
  container.appendChild(blocksDiv);

  // Add legend if multiple blocks
  if (vizData.num_blocks > 1) {
    const legendDiv = document.createElement("div");
    legendDiv.className = "text-xs text-slate-500 mt-1";
    const legends = Array.from({ length: vizData.num_blocks }, (_, i) => `Block ${i + 1}`);
    legendDiv.textContent = `Generating: ${legends.join(" | ")}`;
    container.appendChild(legendDiv);
  }
}
// --- UI Helpers ---
/**
 * Append a chat bubble to the transcript and make sure the chat area is shown.
 *
 * @param {string} role - "user" gets the right-aligned dark bubble; any other
 *   value gets the left-aligned assistant bubble.
 * @param {string} text - Initial text. It is set via textContent, so it is not
 *   interpreted as HTML.
 * @returns {HTMLDivElement} The bubble element. Its `.content > .text-content`
 *   span holds the text.
 */
function addMessage(role, text) {
  const isUser = role === "user";

  const wrapper = document.createElement("div");
  wrapper.className = "mb-6 max-w-[100%] flex flex-col";
  const bubble = document.createElement("div");
  bubble.className = isUser ? "self-end bg-slate-900 text-white p-4 rounded-2xl rounded-tr-sm max-w-[85%]" : "self-start bg-white border border-gray-200 text-slate-800 p-4 rounded-2xl rounded-tl-sm max-w-[65%] whitespace-pre-wrap overflow-x-auto shadow-sm flex flex-wrap";

  // Main Content container that holds the response text
  const pre = document.createElement("div");
  pre.className = "content whitespace-pre-wrap font-sans text-sm leading-relaxed";
  // The actual text content
  const textSpan = document.createElement("span");
  textSpan.className = "text-content";
  textSpan.textContent = text;

  pre.appendChild(textSpan);
  bubble.appendChild(pre);
  wrapper.appendChild(bubble);
  els.chat.appendChild(wrapper);

  // Hide welcome screen and reveal the chat *before* scrolling. While the chat
  // still carries "hidden" (presumably display:none; it looks like a Tailwind
  // utility), it has no layout, so scrolling it would have no effect.
  const welcome = document.getElementById("welcome");
  if (welcome) {
    welcome.classList.add("hidden");
  }
  els.chat.classList.remove("hidden");
  scrollToBottom();
  return bubble;
}
/** Scroll the chat transcript all the way down so the newest content is visible. */
function scrollToBottom() {
  const { chat } = els;
  chat.scrollTop = chat.scrollHeight;
}
// Sidebar Toggle: flip the "-translate-x-full" class on the sidebar.
els.sidebarToggle.addEventListener("click", function onSidebarToggle() {
  els.sidebar.classList.toggle("-translate-x-full");
});

// New Chat Button: empty and hide the transcript, bring the welcome screen
// back (if present), and clear the prompt box.
els.newChatBtn.addEventListener("click", function onNewChat() {
  const { chat, promptInput } = els;
  chat.innerHTML = "";
  chat.classList.add("hidden");
  const welcome = document.getElementById("welcome");
  if (welcome !== null) {
    welcome.classList.remove("hidden");
  }
  promptInput.value = "";
});
// Initialize: if the backend doesn't already have the model loaded, start
// loading it automatically. The Load button stays disabled while that runs and
// remains available afterwards only if the load fails.
(async () => {
  await checkLoadStatus();
  if (isModelLoaded) return;

  els.loadBtn.disabled = true;
  els.status.textContent = "Loading Model (this may take time)...";
  els.status.className = "text-sm text-yellow-600 font-medium";
  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: false }),
    });
    // Error responses are not guaranteed to be JSON. Don't let a parse
    // failure hide the HTTP status.
    let data = {};
    try {
      data = await res.json();
    } catch (parseErr) {
      if (res.ok) throw parseErr;
    }
    if (!res.ok) {
      const serverMessage = data && data.message;
      throw new Error(serverMessage || `Load failed (HTTP ${res.status})`);
    }
    isModelLoaded = true;
    els.status.textContent = "Model Loaded";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (e) {
    // No alert during page start-up, but keep the cause visible for debugging
    // instead of discarding it.
    console.error("Automatic model load failed:", e);
    els.status.textContent = "Error Loading";
    els.status.className = "text-sm text-red-500";
  } finally {
    els.loadBtn.disabled = false;
  }
})();
// The chat area starts hidden. addMessage removes the "hidden" class when
// the first message is added.
els.chat.classList.add("hidden");