| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use crate::core::Id; |
|
|
| |
/// Origin of a piece of text in an attention stream.
///
/// Each variant has a lowercase name (`Role::as_str`) and a one-byte tag used
/// by the binary format of [`AttentionState`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    /// `"system"`.
    System,
    /// `"user"`.
    User,
    /// `"assistant"`.
    Assistant,
    /// `"tool"`; `Role::from_str` also accepts `"function"`.
    Tool,
    /// `"context"`; `Role::from_str` also accepts `"retrieved"`.
    Context,
}
|
|
| impl Role { |
| pub fn as_str(&self) -> &'static str { |
| match self { |
| Role::System => "system", |
| Role::User => "user", |
| Role::Assistant => "assistant", |
| Role::Tool => "tool", |
| Role::Context => "context", |
| } |
| } |
|
|
| pub fn from_str(s: &str) -> Option<Self> { |
| match s.to_lowercase().as_str() { |
| "system" => Some(Role::System), |
| "user" => Some(Role::User), |
| "assistant" => Some(Role::Assistant), |
| "tool" | "function" => Some(Role::Tool), |
| "context" | "retrieved" => Some(Role::Context), |
| _ => None, |
| } |
| } |
|
|
| fn to_byte(&self) -> u8 { |
| match self { |
| Role::System => 0, |
| Role::User => 1, |
| Role::Assistant => 2, |
| Role::Tool => 3, |
| Role::Context => 4, |
| } |
| } |
|
|
| fn from_byte(b: u8) -> Option<Self> { |
| match b { |
| 0 => Some(Role::System), |
| 1 => Some(Role::User), |
| 2 => Some(Role::Assistant), |
| 3 => Some(Role::Tool), |
| 4 => Some(Role::Context), |
| _ => None, |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
/// Opaque key/value-cache payload together with descriptive header fields.
///
/// The dimension fields and the `quantization` label are stored and
/// serialized verbatim. Nothing in this module checks them against
/// `data.len()`.
///
/// `PartialEq`/`Eq` are derived so a value can be compared with its
/// `to_bytes` → `from_bytes` round trip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressedKV {
    /// Identifier of the producing model, e.g. `"llama-3-8b"`.
    pub model_id: String,

    /// Layer count as reported by the producer (not validated here).
    pub num_layers: u32,

    /// Head count as reported by the producer (not validated here).
    pub num_heads: u32,

    /// Per-head dimension as reported by the producer (not validated here).
    pub head_dim: u32,

    /// Sequence length as reported by the producer (not validated here).
    pub seq_len: u32,

    /// Free-form quantization label, e.g. `"fp16"`; `"none"` for placeholders.
    pub quantization: String,

    /// Raw payload bytes; this module never interprets them.
    pub data: Vec<u8>,
}
|
|
| impl CompressedKV { |
| |
| pub fn size_bytes(&self) -> usize { |
| self.data.len() |
| } |
|
|
| |
| pub fn placeholder(model_id: &str) -> Self { |
| Self { |
| model_id: model_id.to_string(), |
| num_layers: 0, |
| num_heads: 0, |
| head_dim: 0, |
| seq_len: 0, |
| quantization: "none".to_string(), |
| data: vec![], |
| } |
| } |
|
|
| |
| pub fn to_bytes(&self) -> Vec<u8> { |
| let mut bytes = Vec::new(); |
|
|
| |
| let model_bytes = self.model_id.as_bytes(); |
| bytes.extend_from_slice(&(model_bytes.len() as u32).to_le_bytes()); |
| bytes.extend_from_slice(model_bytes); |
|
|
| |
| bytes.extend_from_slice(&self.num_layers.to_le_bytes()); |
| bytes.extend_from_slice(&self.num_heads.to_le_bytes()); |
| bytes.extend_from_slice(&self.head_dim.to_le_bytes()); |
| bytes.extend_from_slice(&self.seq_len.to_le_bytes()); |
|
|
| |
| let quant_bytes = self.quantization.as_bytes(); |
| bytes.extend_from_slice(&(quant_bytes.len() as u32).to_le_bytes()); |
| bytes.extend_from_slice(quant_bytes); |
|
|
| |
| bytes.extend_from_slice(&(self.data.len() as u64).to_le_bytes()); |
| bytes.extend_from_slice(&self.data); |
|
|
| bytes |
| } |
|
|
| |
| pub fn from_bytes(data: &[u8]) -> Option<(Self, usize)> { |
| let mut offset = 0; |
|
|
| |
| if data.len() < offset + 4 { |
| return None; |
| } |
| let model_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + model_len { |
| return None; |
| } |
| let model_id = String::from_utf8(data[offset..offset + model_len].to_vec()).ok()?; |
| offset += model_len; |
|
|
| |
| if data.len() < offset + 16 { |
| return None; |
| } |
| let num_layers = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?); |
| offset += 4; |
| let num_heads = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?); |
| offset += 4; |
| let head_dim = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?); |
| offset += 4; |
| let seq_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?); |
| offset += 4; |
|
|
| |
| if data.len() < offset + 4 { |
| return None; |
| } |
| let quant_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + quant_len { |
| return None; |
| } |
| let quantization = String::from_utf8(data[offset..offset + quant_len].to_vec()).ok()?; |
| offset += quant_len; |
|
|
| |
| if data.len() < offset + 8 { |
| return None; |
| } |
| let data_len = u64::from_le_bytes(data[offset..offset + 8].try_into().ok()?) as usize; |
| offset += 8; |
|
|
| if data.len() < offset + data_len { |
| return None; |
| } |
| let kv_data = data[offset..offset + data_len].to_vec(); |
| offset += data_len; |
|
|
| Some(( |
| Self { |
| model_id, |
| num_layers, |
| num_heads, |
| head_dim, |
| seq_len, |
| quantization, |
| data: kv_data, |
| }, |
| offset, |
| )) |
| } |
| } |
|
|
| |
| #[derive(Debug, Clone)] |
| pub struct AttentionState { |
| |
| pub id: Id, |
|
|
| |
| pub timestamp_ms: u64, |
|
|
| |
| pub role: Role, |
|
|
| |
| pub text: String, |
|
|
| |
| pub embedding: Vec<f32>, |
|
|
| |
| pub kv_cache: Option<CompressedKV>, |
|
|
| |
| pub metadata: std::collections::HashMap<String, String>, |
| } |
|
|
| impl AttentionState { |
| |
| pub fn new(role: Role, text: String, embedding: Vec<f32>) -> Self { |
| Self { |
| id: Id::now(), |
| timestamp_ms: std::time::SystemTime::now() |
| .duration_since(std::time::UNIX_EPOCH) |
| .unwrap() |
| .as_millis() as u64, |
| role, |
| text, |
| embedding, |
| kv_cache: None, |
| metadata: std::collections::HashMap::new(), |
| } |
| } |
|
|
| |
| pub fn with_kv_cache(mut self, kv: CompressedKV) -> Self { |
| self.kv_cache = Some(kv); |
| self |
| } |
|
|
| |
| pub fn with_metadata(mut self, key: &str, value: &str) -> Self { |
| self.metadata.insert(key.to_string(), value.to_string()); |
| self |
| } |
|
|
| |
| pub fn size_bytes(&self) -> usize { |
| 16 + |
| 8 + |
| 1 + |
| self.text.len() + |
| self.embedding.len() * 4 + |
| self.kv_cache.as_ref().map(|kv| kv.size_bytes()).unwrap_or(0) + |
| self.metadata.iter().map(|(k, v)| k.len() + v.len() + 8).sum::<usize>() |
| } |
|
|
| |
| pub fn to_bytes(&self) -> Vec<u8> { |
| let mut bytes = Vec::new(); |
|
|
| |
| bytes.extend_from_slice(b"ATTN"); |
| bytes.extend_from_slice(&1u32.to_le_bytes()); |
|
|
| |
| bytes.extend_from_slice(self.id.as_bytes()); |
|
|
| |
| bytes.extend_from_slice(&self.timestamp_ms.to_le_bytes()); |
|
|
| |
| bytes.push(self.role.to_byte()); |
|
|
| |
| let text_bytes = self.text.as_bytes(); |
| bytes.extend_from_slice(&(text_bytes.len() as u32).to_le_bytes()); |
| bytes.extend_from_slice(text_bytes); |
|
|
| |
| bytes.extend_from_slice(&(self.embedding.len() as u32).to_le_bytes()); |
| for &v in &self.embedding { |
| bytes.extend_from_slice(&v.to_le_bytes()); |
| } |
|
|
| |
| if let Some(ref kv) = self.kv_cache { |
| bytes.push(1); |
| let kv_bytes = kv.to_bytes(); |
| bytes.extend_from_slice(&(kv_bytes.len() as u64).to_le_bytes()); |
| bytes.extend_from_slice(&kv_bytes); |
| } else { |
| bytes.push(0); |
| } |
|
|
| |
| bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes()); |
| for (key, value) in &self.metadata { |
| let key_bytes = key.as_bytes(); |
| let value_bytes = value.as_bytes(); |
| bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes()); |
| bytes.extend_from_slice(key_bytes); |
| bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes()); |
| bytes.extend_from_slice(value_bytes); |
| } |
|
|
| bytes |
| } |
|
|
| |
| pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
| let mut offset = 0; |
|
|
| |
| if data.len() < 8 { |
| return Err(AttentionError::InvalidFormat("Too short".into())); |
| } |
| if &data[0..4] != b"ATTN" { |
| return Err(AttentionError::InvalidMagic); |
| } |
| offset += 4; |
|
|
| |
| let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
| if version != 1 { |
| return Err(AttentionError::UnsupportedVersion(version)); |
| } |
| offset += 4; |
|
|
| |
| if data.len() < offset + 16 { |
| return Err(AttentionError::InvalidFormat("Missing ID".into())); |
| } |
| let mut id_bytes = [0u8; 16]; |
| id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| let id = Id::from_bytes(id_bytes); |
| offset += 16; |
|
|
| |
| if data.len() < offset + 8 { |
| return Err(AttentionError::InvalidFormat("Missing timestamp".into())); |
| } |
| let timestamp_ms = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()); |
| offset += 8; |
|
|
| |
| if data.len() < offset + 1 { |
| return Err(AttentionError::InvalidFormat("Missing role".into())); |
| } |
| let role = Role::from_byte(data[offset]) |
| .ok_or_else(|| AttentionError::InvalidFormat("Invalid role".into()))?; |
| offset += 1; |
|
|
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing text length".into())); |
| } |
| let text_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + text_len { |
| return Err(AttentionError::InvalidFormat("Text truncated".into())); |
| } |
| let text = String::from_utf8(data[offset..offset + text_len].to_vec()) |
| .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in text".into()))?; |
| offset += text_len; |
|
|
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing embedding length".into())); |
| } |
| let emb_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + emb_len * 4 { |
| return Err(AttentionError::InvalidFormat("Embedding truncated".into())); |
| } |
| let mut embedding = Vec::with_capacity(emb_len); |
| for _ in 0..emb_len { |
| embedding.push(f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap())); |
| offset += 4; |
| } |
|
|
| |
| if data.len() < offset + 1 { |
| return Err(AttentionError::InvalidFormat("Missing KV flag".into())); |
| } |
| let has_kv = data[offset] != 0; |
| offset += 1; |
|
|
| let kv_cache = if has_kv { |
| if data.len() < offset + 8 { |
| return Err(AttentionError::InvalidFormat("Missing KV length".into())); |
| } |
| let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
| offset += 8; |
|
|
| if data.len() < offset + kv_len { |
| return Err(AttentionError::InvalidFormat("KV data truncated".into())); |
| } |
| let (kv, _) = CompressedKV::from_bytes(&data[offset..offset + kv_len]) |
| .ok_or_else(|| AttentionError::InvalidFormat("Invalid KV cache".into()))?; |
| offset += kv_len; |
| Some(kv) |
| } else { |
| None |
| }; |
|
|
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing metadata count".into())); |
| } |
| let meta_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| let mut metadata = std::collections::HashMap::new(); |
| for _ in 0..meta_count { |
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing key length".into())); |
| } |
| let key_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + key_len { |
| return Err(AttentionError::InvalidFormat("Key truncated".into())); |
| } |
| let key = String::from_utf8(data[offset..offset + key_len].to_vec()) |
| .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in key".into()))?; |
| offset += key_len; |
|
|
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing value length".into())); |
| } |
| let value_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| if data.len() < offset + value_len { |
| return Err(AttentionError::InvalidFormat("Value truncated".into())); |
| } |
| let value = String::from_utf8(data[offset..offset + value_len].to_vec()) |
| .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in value".into()))?; |
| offset += value_len; |
|
|
| metadata.insert(key, value); |
| } |
|
|
| Ok(Self { |
| id, |
| timestamp_ms, |
| role, |
| text, |
| embedding, |
| kv_cache, |
| metadata, |
| }) |
| } |
| } |
|
|
| |
/// Failure reported by [`AttentionState::from_bytes`] and
/// [`AttentionBatch::from_bytes`].
#[derive(Clone, Debug)]
pub enum AttentionError {
    /// The leading four bytes were not the expected magic (`"ATTN"` or `"ATNB"`).
    InvalidMagic,
    /// The header carried a format version other than `1`.
    UnsupportedVersion(u32),
    /// Truncated or malformed input; the string names the offending field.
    InvalidFormat(String),
}

impl std::fmt::Display for AttentionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => f.write_str("Invalid magic bytes"),
            Self::UnsupportedVersion(version) => write!(f, "Unsupported version: {}", version),
            Self::InvalidFormat(detail) => write!(f, "Invalid format: {}", detail),
        }
    }
}

/// Uses the default `source()` (there is no underlying cause).
impl std::error::Error for AttentionError {}
|
|
| |
| #[derive(Debug, Clone)] |
| pub struct AttentionBatch { |
| |
| pub states: Vec<AttentionState>, |
|
|
| |
| pub session_id: Option<Id>, |
|
|
| |
| pub document_id: Option<Id>, |
| } |
|
|
| impl AttentionBatch { |
| pub fn new() -> Self { |
| Self { |
| states: Vec::new(), |
| session_id: None, |
| document_id: None, |
| } |
| } |
|
|
| pub fn with_session(mut self, session_id: Id) -> Self { |
| self.session_id = Some(session_id); |
| self |
| } |
|
|
| pub fn with_document(mut self, document_id: Id) -> Self { |
| self.document_id = Some(document_id); |
| self |
| } |
|
|
| pub fn add(&mut self, state: AttentionState) { |
| self.states.push(state); |
| } |
|
|
| |
| pub fn size_bytes(&self) -> usize { |
| self.states.iter().map(|s| s.size_bytes()).sum() |
| } |
|
|
| |
| pub fn to_bytes(&self) -> Vec<u8> { |
| let mut bytes = Vec::new(); |
|
|
| |
| bytes.extend_from_slice(b"ATNB"); |
| bytes.extend_from_slice(&1u32.to_le_bytes()); |
|
|
| |
| if let Some(sid) = self.session_id { |
| bytes.push(1); |
| bytes.extend_from_slice(sid.as_bytes()); |
| } else { |
| bytes.push(0); |
| } |
|
|
| |
| if let Some(did) = self.document_id { |
| bytes.push(1); |
| bytes.extend_from_slice(did.as_bytes()); |
| } else { |
| bytes.push(0); |
| } |
|
|
| |
| bytes.extend_from_slice(&(self.states.len() as u32).to_le_bytes()); |
|
|
| |
| for state in &self.states { |
| let state_bytes = state.to_bytes(); |
| bytes.extend_from_slice(&(state_bytes.len() as u64).to_le_bytes()); |
| bytes.extend_from_slice(&state_bytes); |
| } |
|
|
| bytes |
| } |
|
|
| |
| pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> { |
| let mut offset = 0; |
|
|
| |
| if data.len() < 8 { |
| return Err(AttentionError::InvalidFormat("Too short".into())); |
| } |
| if &data[0..4] != b"ATNB" { |
| return Err(AttentionError::InvalidMagic); |
| } |
| offset += 4; |
|
|
| |
| let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()); |
| if version != 1 { |
| return Err(AttentionError::UnsupportedVersion(version)); |
| } |
| offset += 4; |
|
|
| |
| if data.len() < offset + 1 { |
| return Err(AttentionError::InvalidFormat("Missing session flag".into())); |
| } |
| let has_session = data[offset] != 0; |
| offset += 1; |
|
|
| let session_id = if has_session { |
| if data.len() < offset + 16 { |
| return Err(AttentionError::InvalidFormat("Missing session ID".into())); |
| } |
| let mut id_bytes = [0u8; 16]; |
| id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| offset += 16; |
| Some(Id::from_bytes(id_bytes)) |
| } else { |
| None |
| }; |
|
|
| |
| if data.len() < offset + 1 { |
| return Err(AttentionError::InvalidFormat("Missing document flag".into())); |
| } |
| let has_document = data[offset] != 0; |
| offset += 1; |
|
|
| let document_id = if has_document { |
| if data.len() < offset + 16 { |
| return Err(AttentionError::InvalidFormat("Missing document ID".into())); |
| } |
| let mut id_bytes = [0u8; 16]; |
| id_bytes.copy_from_slice(&data[offset..offset + 16]); |
| offset += 16; |
| Some(Id::from_bytes(id_bytes)) |
| } else { |
| None |
| }; |
|
|
| |
| if data.len() < offset + 4 { |
| return Err(AttentionError::InvalidFormat("Missing state count".into())); |
| } |
| let state_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize; |
| offset += 4; |
|
|
| |
| let mut states = Vec::with_capacity(state_count); |
| for _ in 0..state_count { |
| if data.len() < offset + 8 { |
| return Err(AttentionError::InvalidFormat("Missing state length".into())); |
| } |
| let state_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; |
| offset += 8; |
|
|
| if data.len() < offset + state_len { |
| return Err(AttentionError::InvalidFormat("State truncated".into())); |
| } |
| let state = AttentionState::from_bytes(&data[offset..offset + state_len])?; |
| offset += state_len; |
| states.push(state); |
| } |
|
|
| Ok(Self { |
| states, |
| session_id, |
| document_id, |
| }) |
| } |
| } |
|
|
| impl Default for AttentionBatch { |
| fn default() -> Self { |
| Self::new() |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_ROLES: [Role; 5] = [
        Role::System,
        Role::User,
        Role::Assistant,
        Role::Tool,
        Role::Context,
    ];

    /// KV cache with distinctive, non-default values for round-trip checks.
    fn sample_kv() -> CompressedKV {
        CompressedKV {
            model_id: "llama-3-8b".to_string(),
            num_layers: 32,
            num_heads: 32,
            head_dim: 128,
            seq_len: 10,
            quantization: "fp16".to_string(),
            data: vec![1, 2, 3, 4, 5],
        }
    }

    #[test]
    fn test_role_roundtrip() {
        for role in ALL_ROLES {
            let byte = role.to_byte();
            assert_eq!(Role::from_byte(byte), Some(role));
        }
        assert_eq!(Role::from_byte(5), None);
    }

    #[test]
    fn test_role_from_str() {
        for role in ALL_ROLES {
            assert_eq!(Role::from_str(role.as_str()), Some(role));
        }
        assert_eq!(Role::from_str("SYSTEM"), Some(Role::System));
        assert_eq!(Role::from_str("Function"), Some(Role::Tool));
        assert_eq!(Role::from_str("retrieved"), Some(Role::Context));
        assert_eq!(Role::from_str("narrator"), None);
    }

    #[test]
    fn test_attention_state_roundtrip() {
        let state = AttentionState::new(
            Role::User,
            "Hello, how are you?".to_string(),
            vec![0.1, 0.2, 0.3, 0.4],
        )
        .with_metadata("turn", "1");

        let bytes = state.to_bytes();
        let restored = AttentionState::from_bytes(&bytes).unwrap();

        assert_eq!(state.id.as_bytes(), restored.id.as_bytes());
        assert_eq!(state.timestamp_ms, restored.timestamp_ms);
        assert_eq!(state.role, restored.role);
        assert_eq!(state.text, restored.text);
        assert_eq!(state.embedding, restored.embedding);
        assert_eq!(state.metadata, restored.metadata);
        assert!(restored.kv_cache.is_none());
    }

    #[test]
    fn test_empty_attention_state_roundtrip() {
        let state = AttentionState::new(Role::System, String::new(), Vec::new());
        let restored = AttentionState::from_bytes(&state.to_bytes()).unwrap();

        assert_eq!(restored.role, Role::System);
        assert!(restored.text.is_empty());
        assert!(restored.embedding.is_empty());
        assert!(restored.kv_cache.is_none());
        assert!(restored.metadata.is_empty());
    }

    #[test]
    fn test_attention_state_with_kv() {
        let state = AttentionState::new(
            Role::Assistant,
            "I'm doing well!".to_string(),
            vec![0.5, 0.6, 0.7, 0.8],
        )
        .with_kv_cache(sample_kv());

        let restored = AttentionState::from_bytes(&state.to_bytes()).unwrap();
        let kv = restored
            .kv_cache
            .expect("KV cache should survive the round trip");
        let expected = sample_kv();

        assert_eq!(kv.model_id, expected.model_id);
        assert_eq!(kv.num_layers, expected.num_layers);
        assert_eq!(kv.num_heads, expected.num_heads);
        assert_eq!(kv.head_dim, expected.head_dim);
        assert_eq!(kv.seq_len, expected.seq_len);
        assert_eq!(kv.quantization, expected.quantization);
        assert_eq!(kv.data, expected.data);
    }

    #[test]
    fn test_compressed_kv_reports_consumed_len_and_rejects_truncation() {
        let kv = sample_kv();
        let mut bytes = kv.to_bytes();
        let encoded_len = bytes.len();
        // Trailing bytes must be left unconsumed.
        bytes.extend_from_slice(&[0xAA, 0xBB]);

        let (restored, consumed) = CompressedKV::from_bytes(&bytes).unwrap();
        assert_eq!(consumed, encoded_len);
        assert_eq!(restored.data, kv.data);

        for cut in 0..encoded_len {
            assert!(
                CompressedKV::from_bytes(&bytes[..cut]).is_none(),
                "accepted a {}-byte prefix",
                cut
            );
        }
    }

    #[test]
    fn test_attention_state_rejects_bad_header() {
        assert!(matches!(
            AttentionState::from_bytes(b"ATTN"),
            Err(AttentionError::InvalidFormat(_))
        ));
        assert!(matches!(
            AttentionState::from_bytes(b"NOPE\x01\x00\x00\x00"),
            Err(AttentionError::InvalidMagic)
        ));

        let mut bytes = AttentionState::new(Role::User, "x".to_string(), vec![]).to_bytes();
        bytes[4..8].copy_from_slice(&2u32.to_le_bytes());
        assert!(matches!(
            AttentionState::from_bytes(&bytes),
            Err(AttentionError::UnsupportedVersion(2))
        ));
    }

    #[test]
    fn test_attention_state_rejects_invalid_role() {
        let mut bytes = AttentionState::new(Role::User, "x".to_string(), vec![]).to_bytes();
        // The role tag follows magic (4) + version (4) + id (16) + timestamp (8).
        bytes[32] = 0xFF;
        assert!(matches!(
            AttentionState::from_bytes(&bytes),
            Err(AttentionError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_attention_state_rejects_every_truncation() {
        let state = AttentionState::new(Role::Tool, "hi".to_string(), vec![1.0, -2.5])
            .with_kv_cache(sample_kv())
            .with_metadata("k", "v");
        let bytes = state.to_bytes();

        for cut in 0..bytes.len() {
            assert!(
                AttentionState::from_bytes(&bytes[..cut]).is_err(),
                "accepted a {}-byte prefix",
                cut
            );
        }
    }

    #[test]
    fn test_batch_roundtrip() {
        let mut batch = AttentionBatch::new().with_session(Id::now());

        batch.add(AttentionState::new(
            Role::User,
            "Question 1".to_string(),
            vec![0.1, 0.2],
        ));

        batch.add(AttentionState::new(
            Role::Assistant,
            "Answer 1".to_string(),
            vec![0.3, 0.4],
        ));

        let bytes = batch.to_bytes();
        let restored = AttentionBatch::from_bytes(&bytes).unwrap();

        assert_eq!(restored.states.len(), 2);
        assert_eq!(restored.states[0].text, "Question 1");
        assert_eq!(restored.states[1].text, "Answer 1");
        assert!(restored.session_id.is_some());
        assert!(restored.document_id.is_none());
    }

    #[test]
    fn test_batch_with_document_roundtrip() {
        let mut batch = AttentionBatch::new().with_document(Id::now());
        batch.add(AttentionState::new(
            Role::Context,
            "doc chunk".to_string(),
            vec![],
        ));

        let restored = AttentionBatch::from_bytes(&batch.to_bytes()).unwrap();

        assert!(restored.session_id.is_none());
        assert!(restored.document_id.is_some());
        assert_eq!(restored.states.len(), 1);
        assert_eq!(restored.states[0].role, Role::Context);
    }

    #[test]
    fn test_empty_batch_roundtrip() {
        let restored = AttentionBatch::from_bytes(&AttentionBatch::default().to_bytes()).unwrap();

        assert!(restored.states.is_empty());
        assert!(restored.session_id.is_none());
        assert!(restored.document_id.is_none());
    }

    #[test]
    fn test_batch_rejects_bad_header_and_truncation() {
        assert!(matches!(
            AttentionBatch::from_bytes(b"ATTN\x01\x00\x00\x00"),
            Err(AttentionError::InvalidMagic)
        ));

        let mut batch = AttentionBatch::new()
            .with_session(Id::now())
            .with_document(Id::now());
        batch.add(AttentionState::new(Role::User, "q".to_string(), vec![0.5]));
        let bytes = batch.to_bytes();

        for cut in 0..bytes.len() {
            assert!(
                AttentionBatch::from_bytes(&bytes[..cut]).is_err(),
                "accepted a {}-byte prefix",
                cut
            );
        }
    }
}
|
|