Spaces:
Sleeping
Sleeping
// Global state
// Set to true once a /api/load response confirms the model is available.
let isModelLoaded = false;

// DOM Elements
// Every element the script touches is resolved once, by id, at start-up.
const els = (() => {
  const byId = (id) => document.getElementById(id);

  return {
    // Chat area and prompt form
    chat: byId("chat"),
    promptForm: byId("prompt-form"),
    promptInput: byId("prompt"),
    // Model loading controls
    loadBtn: byId("btn-load"),
    testStreamBtn: byId("btn-test-stream"),
    status: byId("load-status"),
    // Sidebar
    sidebar: byId("sidebar"),
    sidebarToggle: byId("btn-toggle-sidebar"),
    chatList: byId("chat-list"),
    newChatBtn: byId("new-chat"),
    sendBtn: byId("btn-send"),
    // Generation parameter sliders
    steps: byId("steps"),
    block_size: byId("block_size"),
    max_new_tokens: byId("max_new_tokens"),
    parallel_blocks: byId("parallel_blocks"),
    // Read-outs that mirror the slider values
    stepsValue: byId("steps-value"),
    block_sizeValue: byId("block_size-value"),
    max_new_tokensValue: byId("max_new_tokens-value"),
    parallel_blocksValue: byId("parallel_blocks-value"),
  };
})();
// Update slider values
// Each slider `els[name]` has a matching read-out `els[name + "Value"]`
// that is refreshed on every "input" event.
for (const name of ["steps", "block_size", "max_new_tokens", "parallel_blocks"]) {
  els[name].addEventListener("input", () => {
    els[`${name}Value`].textContent = els[name].value;
  });
}
// --- Logic ---

/**
 * Asks the backend (POST /api/load with `check_only: true`) whether the
 * model is already loaded. When it is, marks the global flag, shows "Ready"
 * in the status label and hides the load button.
 *
 * Any failure (network, bad JSON, unexpected payload) is logged and
 * otherwise ignored, leaving the UI in its current state.
 *
 * @returns {Promise<void>}
 */
async function checkLoadStatus() {
  try {
    const response = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: true }),
    });
    if (!response.ok) {
      return;
    }

    const payload = await response.json();
    if (!payload.loaded) {
      return;
    }

    isModelLoaded = true;
    els.status.textContent = "Ready";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (err) {
    console.log("Model check failed:", err);
  }
}
/**
 * Manual "load model" button.
 *
 * Sends POST /api/load with `check_only: false`, reflects progress in the
 * status label, and on success marks the model as loaded and hides the
 * button. On failure the status turns red, the error is logged and shown in
 * an alert, and the button is re-enabled so the user can retry.
 */
els.loadBtn.addEventListener("click", async () => {
  els.loadBtn.disabled = true;
  els.status.textContent = "Loading Model (this may take time)...";
  els.status.className = "text-sm text-yellow-600 font-medium";

  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: false }),
    });

    if (!res.ok) {
      // Check the status BEFORE decoding the body: an error response is not
      // guaranteed to be JSON (e.g. an HTML page from a proxy), and a JSON
      // SyntaxError would otherwise hide the real HTTP failure.
      let message = `Load failed (HTTP ${res.status})`;
      try {
        const data = await res.json();
        if (data && data.message) {
          message = data.message;
        }
      } catch (parseError) {
        // Body was not JSON; keep the status-based message.
      }
      throw new Error(message);
    }

    // The success payload is not used; an OK status is the success signal.
    isModelLoaded = true;
    els.status.textContent = "Model Loaded";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (e) {
    console.error("Model load failed:", e);
    els.status.textContent = "Error Loading";
    els.status.className = "text-sm text-red-500";
    alert("Error: " + e.message);
  } finally {
    els.loadBtn.disabled = false;
  }
});
/**
 * Handles one line of the /api/generate-stream response.
 *
 * Only lines of the form `data: <json>` are acted on; everything else is
 * ignored. JSON decoding and event handling are guarded separately so that a
 * rendering failure is not reported as a parse failure, and so that one bad
 * line never aborts the rest of the stream.
 *
 * @param {string} line - A single line of the response, without its "\n".
 * @param {HTMLElement} visualizationDiv - Container for the live preview.
 * @param {HTMLElement} textContent - Element that receives the final text.
 */
function dispatchStreamLine(line, visualizationDiv, textContent) {
  if (!line.startsWith("data: ")) {
    return;
  }
  const jsonStr = line.slice(6);
  if (!jsonStr.trim()) {
    return;
  }

  let event;
  try {
    event = JSON.parse(jsonStr);
  } catch (e) {
    console.error("Failed to parse SSE data:", e);
    return;
  }

  try {
    handleStreamEvent(event, visualizationDiv, textContent);
  } catch (e) {
    console.error("Failed to handle stream event:", e);
  }
}

/**
 * Prompt form submission.
 *
 * Adds the user message and an empty assistant bubble with a spinner, then
 * POSTs the prompt plus the slider settings to /api/generate-stream and feeds
 * every streamed line to `dispatchStreamLine`. The send button and input are
 * disabled for the duration of the request and always restored afterwards.
 */
els.promptForm.addEventListener("submit", async (e) => {
  e.preventDefault();
  const text = els.promptInput.value.trim();
  if (!text) return;

  // UI Updates
  addMessage("user", text);
  els.promptInput.value = "";

  // Create Assistant Bubble
  const assistantBubble = addMessage("assistant", "");
  const contentPre = assistantBubble.querySelector(".content");
  const textContent = contentPre.querySelector(".text-content");
  const visualizationDiv = document.createElement("div");
  visualizationDiv.className = "visualization mb-2 font-mono text-xs";

  // Loading spinner (SVG) - static markup only, no user data is interpolated.
  const spinner = document.createElement("div");
  spinner.className = "flex items-center gap-2 text-slate-400";
  spinner.innerHTML = `
<svg class="animate-spin h-4 w-4" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v4a4 4 0 00-4 4H4z"></path>
</svg>
<span class="text-xs">Generating...</span>
`;
  visualizationDiv.appendChild(spinner);
  contentPre.insertBefore(visualizationDiv, textContent);

  // Disable send button
  els.sendBtn.disabled = true;
  els.sendBtn.textContent = "Generating...";
  els.promptInput.disabled = true;

  // Generate Request with Streaming
  try {
    const res = await fetch("/api/generate-stream", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({
        instruction: text,
        // Explicit radix: slider values are decimal strings.
        steps: parseInt(els.steps.value, 10),
        block_size: parseInt(els.block_size.value, 10),
        max_new_tokens: parseInt(els.max_new_tokens.value, 10),
        parallel_blocks: parseInt(els.parallel_blocks.value, 10),
      }),
    });
    if (!res.ok) {
      throw new Error(`Server Error ${res.status}`);
    }
    if (!res.body) {
      throw new Error("Response has no body to stream");
    }

    const reader = res.body.getReader();
    const decoder = new TextDecoder();
    let buffer = "";
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      buffer = lines.pop(); // Keep incomplete line in buffer
      for (const line of lines) {
        dispatchStreamLine(line, visualizationDiv, textContent);
      }
    }

    // Flush bytes still held by the decoder and handle a final line that was
    // not terminated by "\n"; otherwise the last event would be lost.
    buffer += decoder.decode();
    if (buffer) {
      dispatchStreamLine(buffer, visualizationDiv, textContent);
    }
  } catch (error) {
    // Drop the spinner / partial preview so it does not sit next to the error.
    visualizationDiv.innerHTML = "";
    if (textContent) textContent.textContent = `Error: ${error.message}`;
  } finally {
    // The stream may end without a "complete" event; never leave the spinner
    // running. This is a no-op when the spinner was already removed.
    spinner.remove();
    els.sendBtn.disabled = false;
    els.sendBtn.textContent = "Send";
    els.promptInput.disabled = false;
  }
});
/**
 * Applies one decoded stream event to the assistant bubble.
 *
 * Event types handled:
 *  - "start":    clears the response text.
 *  - "update":   redraws the live preview from `data.data` and scrolls down.
 *  - "complete": removes the preview, shows `data.response` (or the
 *                placeholder "No response") and scrolls down.
 *  - "error":    removes the preview and shows the error message.
 * Any other type is ignored.
 *
 * @param {{type: string, data?: object, response?: string, error?: string}} data
 *   Decoded event payload.
 * @param {HTMLElement} visualizationDiv - Container for the live preview.
 * @param {HTMLElement} textContent - Element that receives the response text.
 */
function handleStreamEvent(data, visualizationDiv, textContent) {
  switch (data.type) {
    case "start":
      textContent.textContent = "";
      break;

    case "update":
      renderVisualization(data.data, visualizationDiv);
      scrollToBottom();
      break;

    case "complete":
      // Clear visualization and show final response
      visualizationDiv.innerHTML = "";
      textContent.textContent = data.response || "No response";
      scrollToBottom();
      break;

    case "error":
      // Clear the preview here too, so a spinner or half-rendered preview
      // does not stay on screen beside the error message.
      visualizationDiv.innerHTML = "";
      textContent.textContent = `Error: ${data.error || "Unknown error"}`;
      break;

    default:
      break;
  }
}
/**
 * Redraws the in-progress generation preview inside `container`.
 *
 * The container is emptied and rebuilt from scratch with:
 *  1. a line showing `vizData.context`,
 *  2. one coloured span per entry of `vizData.blocks`, containing that
 *     block's tokens (tokens with `type === "masked"` get their own span and
 *     a trailing space; all others are appended as plain text nodes),
 *  3. a "Generating: Block 1 | Block 2 | ..." legend when
 *     `vizData.num_blocks` is greater than 1.
 *
 * A missing payload, a missing `blocks` array, or a block without `tokens`
 * is treated as empty instead of throwing.
 *
 * @param {{context?: string, blocks?: Array<{tokens?: Array<{type: string, text: string}>}>, num_blocks?: number}} vizData
 *   Preview payload from an "update" stream event.
 * @param {HTMLElement} container - Element whose content is replaced.
 */
function renderVisualization(vizData, container) {
  // Colours cycle when there are more blocks than entries.
  const blockColors = ["text-green-600", "text-cyan-600", "text-yellow-600", "text-purple-600"];
  const viz = vizData || {};
  const blocks = Array.isArray(viz.blocks) ? viz.blocks : [];

  // Clear previous content
  container.innerHTML = "";

  // Show context
  const contextDiv = document.createElement("div");
  contextDiv.className = "text-slate-600 mb-1";
  contextDiv.textContent = viz.context == null ? "" : viz.context;
  container.appendChild(contextDiv);

  // Show blocks
  const blocksDiv = document.createElement("div");
  blocksDiv.classList.add("flex", "flex-wrap", "gap-0");
  blocks.forEach((block, blockIdx) => {
    const colorClass = blockColors[blockIdx % blockColors.length];
    const blockSpan = document.createElement("span");
    blockSpan.className = colorClass;

    const tokens = block && Array.isArray(block.tokens) ? block.tokens : [];
    for (const token of tokens) {
      if (token.type === "masked") {
        const maskedSpan = document.createElement("span");
        maskedSpan.className = colorClass;
        // textContent (not innerText) for consistency with the rest of the
        // file; the text is inserted verbatim either way.
        maskedSpan.textContent = `${token.text} `;
        blockSpan.appendChild(maskedSpan);
      } else {
        blockSpan.appendChild(document.createTextNode(token.text));
      }
    }
    blocksDiv.appendChild(blockSpan);
  });
  container.appendChild(blocksDiv);

  // Add legend if multiple blocks
  if (viz.num_blocks > 1) {
    const legends = [];
    for (let i = 0; i < viz.num_blocks; i++) {
      legends.push(`Block ${i + 1}`);
    }
    const legendDiv = document.createElement("div");
    legendDiv.className = "text-xs text-slate-500 mt-1";
    legendDiv.textContent = `Generating: ${legends.join(" | ")}`;
    container.appendChild(legendDiv);
  }
}
// --- UI Helpers ---

/**
 * Appends a chat bubble to the chat container and returns it.
 *
 * The bubble contains a `.content` element which in turn holds a
 * `.text-content` span with `text`. User bubbles are dark and right-aligned;
 * any other role gets the light, left-aligned assistant styling. Adding a
 * message also hides the welcome screen (if present) and reveals the chat
 * container.
 *
 * @param {string} role - "user" for the user's own messages; anything else
 *   is styled as an assistant message.
 * @param {string} text - Initial text of the message (set via textContent).
 * @returns {HTMLDivElement} The bubble element that was created.
 */
function addMessage(role, text) {
  const isUser = role === "user";

  const wrapper = document.createElement("div");
  wrapper.className = "mb-6 max-w-[100%] flex flex-col";

  const bubble = document.createElement("div");
  bubble.className = isUser ? "self-end bg-slate-900 text-white p-4 rounded-2xl rounded-tr-sm max-w-[85%]" : "self-start bg-white border border-gray-200 text-slate-800 p-4 rounded-2xl rounded-tl-sm max-w-[65%] whitespace-pre-wrap overflow-x-auto shadow-sm flex flex-wrap";

  // Main Content container that holds the response text
  const pre = document.createElement("div");
  pre.className = "content whitespace-pre-wrap font-sans text-sm leading-relaxed";

  // The actual text content
  const textSpan = document.createElement("span");
  textSpan.className = "text-content";
  textSpan.textContent = text;

  pre.appendChild(textSpan);
  bubble.appendChild(pre);
  wrapper.appendChild(bubble);
  els.chat.appendChild(wrapper);

  // Hide welcome screen
  const welcome = document.getElementById("welcome");
  if (welcome) {
    welcome.classList.add("hidden");
  }
  els.chat.classList.remove("hidden");

  // Scroll only after the container has been revealed: while it is hidden
  // it has no scrollable height, so scrolling earlier would do nothing.
  scrollToBottom();

  return bubble;
}
/**
 * Scrolls the chat container all the way down by setting its scrollTop to
 * its current scrollHeight.
 */
function scrollToBottom() {
  const { chat } = els;
  chat.scrollTop = chat.scrollHeight;
}
// Sidebar Toggle
// Toggles the "-translate-x-full" class on the sidebar each time the toggle
// button is clicked.
els.sidebarToggle.addEventListener("click", function handleSidebarToggle() {
  const { classList } = els.sidebar;
  classList.toggle("-translate-x-full");
});
// New Chat Button
// Resets the view: empties and hides the chat container, shows the welcome
// screen again when that element exists, and clears the prompt input.
els.newChatBtn.addEventListener("click", function handleNewChat() {
  const { chat, promptInput } = els;

  // Clear chat
  chat.innerHTML = "";
  chat.classList.add("hidden");

  // Show welcome screen
  const welcomeScreen = document.getElementById("welcome");
  if (welcomeScreen !== null) {
    welcomeScreen.classList.remove("hidden");
  }

  // Clear input
  promptInput.value = "";
});
// Initialize
// On start-up, ask the backend whether the model is already loaded; if not,
// trigger a load automatically (same request the load button sends). Unlike
// the button handler, a failure here does not raise an alert: it is logged
// and shown in the status label, and the button stays available for a retry.
(async () => {
  await checkLoadStatus();
  if (isModelLoaded) {
    return;
  }

  els.loadBtn.disabled = true;
  els.status.textContent = "Loading Model (this may take time)...";
  els.status.className = "text-sm text-yellow-600 font-medium";

  try {
    const res = await fetch("/api/load", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ check_only: false }),
    });

    if (!res.ok) {
      // Check the status BEFORE decoding the body: an error response is not
      // guaranteed to be JSON, and a JSON SyntaxError would otherwise hide
      // the real HTTP failure.
      let message = `Load failed (HTTP ${res.status})`;
      try {
        const data = await res.json();
        if (data && data.message) {
          message = data.message;
        }
      } catch (parseError) {
        // Body was not JSON; keep the status-based message.
      }
      throw new Error(message);
    }

    // The success payload is not used; an OK status is the success signal.
    isModelLoaded = true;
    els.status.textContent = "Model Loaded";
    els.status.className = "text-sm text-green-600 font-medium";
    els.loadBtn.style.display = 'none';
  } catch (e) {
    // Previously swallowed silently; keep a trace for debugging.
    console.error("Automatic model load failed:", e);
    els.status.textContent = "Error Loading";
    els.status.className = "text-sm text-red-500";
  } finally {
    els.loadBtn.disabled = false;
  }
})();

// The chat container starts hidden; the first message reveals it.
els.chat.classList.add("hidden");