diff --git a/src/app.py b/src/app.py index 79f58db..12769e2 100644 --- a/src/app.py +++ b/src/app.py @@ -8,8 +8,8 @@ from src.modules.events import get_events from src.modules.factory import bind_deployment_handles from src.modules.modules import get_modules -from src.modules.rag.docker_services import OllamaService, QdrantService - +from src.services.ollama import build_ollama +from src.services.qdrant import build_qdrant def load_services_config() -> Any: config_path = Path(__file__).resolve().parents[1] / "config" / "huri.yaml" @@ -18,24 +18,6 @@ def load_services_config() -> Any: return config.get("services", {}) -def build_qdrant(config: dict) -> Any: - return QdrantService.bind( # type: ignore[attr-defined] - port=config.get("port", 6333), - image=config.get("image", "qdrant/qdrant:latest"), - storage_volume=config.get("storage_volume", "qdrant_data"), - ) - - -def build_ollama(config: dict) -> Any: - return OllamaService.options( # type: ignore[attr-defined] - num_replicas=config.get("num_replicas", 1), - ).bind( - model=config.get("model", "mistral:7b"), - image=config.get("image", "ollama/ollama:latest"), - gpu_devices=config.get("gpu_devices", False), - ) - - def build_app() -> Application: modules = get_modules() events = get_events() diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..9c2e57c 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,12 +1,12 @@ import asyncio import json -import os from dataclasses import asdict -from typing import Dict, List, Optional, Type +from typing import Dict, List, Type import websockets from src.core.dataclasses.config import ClientConfig +from src.core.user_config import get_or_create_user_id from .client_senders import ClientSender, get_senders @@ -17,23 +17,13 @@ class Client: def __init__( self, config: ClientConfig, - user_id_file: str = os.path.expanduser("~/.huri_user_id"), + user_id_file: str | None = None, senders_dict: Dict[str, Type[ClientSender]] = get_senders(), ): self.config = config self.user_id_file = user_id_file self.senders_dict = senders_dict - def _load_user_id(self) -> Optional[str]: - if os.path.exists(self.user_id_file): - with open(self.user_id_file) as f: - return f.read().strip() - return None - - def _save_user_id(self, _user_id: str): - with open(self.user_id_file, "w") as f: - f.write(_user_id) - async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: @@ -48,7 +38,7 @@ async def run(self): async with websockets.connect(self.config.huri_url) as ws: print("Connected to server") - self.config.user_id = self._load_user_id() + self.config.user_id = get_or_create_user_id(self.user_id_file) senders: List[ClientSender] = [ self.senders_dict[config.name](ws=ws, **config.args) @@ -60,7 +50,6 @@ async def run(self): init_msg = json.loads(await ws.recv()) if init_msg.get("type") == "session_init": user_id = init_msg["user_id"] - self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") receive_task = asyncio.create_task(self._receive_loop(ws)) diff --git a/src/core/module.py b/src/core/module.py index 0a571a8..c043b3e 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -7,6 +7,7 @@ class Module: input_type: str output_type: Optional[str] + async def process(self, _) -> Optional[Any]: raise NotImplementedError diff --git a/src/core/user_config.py b/src/core/user_config.py new file mode 100644 index 0000000..19a74ef --- /dev/null +++ b/src/core/user_config.py @@ -0,0 +1,60 @@ +import os +import platform +import uuid +from pathlib import Path + + +def get_config_dir() -> Path: + """Cross-platform config directory.""" + system = platform.system() + + if system == "Windows": + # TODO: To be tested -> also consider language-specific if needed + base = os.environ.get("APPDATA", os.path.expanduser("~/AppData/Roaming")) + elif system == "Darwin": + # TODO: To be tested -> also consider language-specific if needed + base = os.path.expanduser("~/Library/Application Support") + else: + base = os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) + + config_dir = Path(base) / "huri" + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir + + +def load_user_id(path: str | None = None) -> str | None: + """Load existing _user_id, or return None if new user.""" + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + if id_file.exists(): + uid = id_file.read_text().strip() + if uid: + return uid + return None + + +def save_user_id(_user_id: str, path: str | None = None): + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + + id_file.write_text(_user_id) + if platform.system() != "Windows": + id_file.chmod(0o600) + + +def get_or_create_user_id(path: str | None = None) -> str: + """Load existing or generate new _user_id.""" + uid = load_user_id(path) + if uid: + return uid + uid = str(uuid.uuid4()) + save_user_id(uid, path) + return uid diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index f4e4dae..a50f747 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -1,5 +1,4 @@ import argparse -import os import re import sys import uuid @@ -17,10 +16,10 @@ PointStruct, VectorParams, ) -from semantic_chunker import SemanticChunker from sentence_transformers import SentenceTransformer -USER_ID_FILE = os.path.expanduser("~/.huri_user_id") +from src.core.user_config import get_or_create_user_id +from src.modules.rag.semantic_chunker import SemanticChunker def _split_sentences(text: str) -> list[str]: @@ -89,21 +88,6 @@ def extract_text_from_pdf(pdf_path: str) -> str: sys.exit(1) -def get_user_id(provided_id: str | None = None) -> str: - if provided_id: - return provided_id - if os.path.exists(USER_ID_FILE): - with open(USER_ID_FILE) as f: - uid = f.read().strip() - if uid: - return uid - new_id = str(uuid.uuid4()) - with open(USER_ID_FILE, "w") as f: - f.write(new_id) - print(f"Generated new user_id: {new_id}") - return new_id - - def ensure_collection(client: QdrantClient, collection: str, vector_size: int): collections = [c.name for c in client.get_collections().collections] if collection not in collections: @@ -145,7 +129,6 @@ def ingest_chunks( ) if points: - # Upsert in batches of 100 batch_size = 100 for i in range(0, len(points), batch_size): batch = points[i : i + batch_size] @@ -403,7 +386,7 @@ def main(): args = parser.parse_args() - _user_id = get_user_id(args._user_id) + _user_id = get_or_create_user_id() print(f"User: {_user_id}") client = QdrantClient(url=args.qdrant_url) @@ -440,7 +423,7 @@ def main(): # Ingest a text file python ingestion.py text notes.txt story.md - # Specify a user ID (otherwise reads from ~/.huri_user_id) + # Specify a user ID (otherwise it will be auto-generated and saved) python ingestion.py --user-id "abc-123" pdf report.pdf # Use a different collection diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 6b9744d..4ffa898 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -1,3 +1,4 @@ +import json from dataclasses import dataclass, field from typing import Any, Optional @@ -21,6 +22,7 @@ class RAGQuery: _user_id: str question: str preferences: dict = field(default_factory=dict) + history: list[dict] | None = None # preferences can include: language, tone, # response_format, max_length, system_prompt, extra_instructions, etc. @@ -36,6 +38,7 @@ class RAGHandle: collection/data in the vector DB, runs embed -> search -> LLM. """ + def __init__( self, ollama_handle=None, @@ -66,8 +69,8 @@ def __init__( self._qdrant_url = qdrant_url self._qdrant: QdrantClient | None = None + async def _get_qdrant(self): - """Connect to Qdrant on first use. Solves the async-in-init problem.""" if self._qdrant is None: if self.qdrant_handle: self._qdrant_url = await self.qdrant_handle.get_url.remote() @@ -75,25 +78,19 @@ async def _get_qdrant(self): print(f"[RAGHandle] Connected to Qdrant at {self._qdrant_url}") return self._qdrant - def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: - """ - Given a _user_id, decide which collection to search - and which filters to apply. - Options (pick what fits your data model): - A) One collection per user: collection = f"user_{_user_id}" - B) Shared collection, filter by _user_id in payload - C) Lookup in a DB to find the user's config - """ + def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: collection = self.default_collection filters = {"_user_id": _user_id} return collection, filters + def _embed(self, text) -> list[float] | Any: return self.embed_model.encode(str(text), normalize_embeddings=True).tolist() + def _search( self, qdrant, @@ -110,8 +107,9 @@ def _search( ] qdrant_filter = Filter(must=conditions) + doc_results = [] try: - results = qdrant.query_points( + doc_results = qdrant.query_points( collection_name=collection, query=query_vector, query_filter=qdrant_filter, @@ -119,27 +117,36 @@ def _search( score_threshold=self.score_threshold, ).points except Exception: - results = [] + pass + return [ { "text": point.payload.get("text", ""), "score": point.score, "metadata": {k: v for k, v in point.payload.items() if k != "text"}, } - for point in results + for point in doc_results ] + def _build_prompt( self, question: str, chunks: list[dict], preferences: dict, + history=None, ) -> tuple[str, str]: - parts = [ - "You are a robot speaking to a user. Answer based on the provided context.", - "If the context is insufficient, say so clearly.", - ] + parts = [] + + if history: + lines = [f"{m['role']}: {m['content']}" for m in history] + parts.append("[Recent conversation]\n" + "\n".join(lines)) + + parts.append( + "You are a robot speaking to a user. Answer based on the provided context." + + " If the context is insufficient, say so clearly.", + ) if preferences.get("language"): parts.append(f"Always respond in {preferences['language']}.") if preferences.get("tone"): @@ -153,11 +160,7 @@ def _build_prompt( system_prompt = " ".join(parts) if not chunks: - user_prompt = ( - "No relevant context was found.\n\n" - f"Question: {question}\n\n" - "Answer based on general knowledge." - ) + user_prompt = f"Question: {question}\n\n" else: context_parts = [] for i, chunk in enumerate(chunks, 1): @@ -174,6 +177,7 @@ def _build_prompt( return system_prompt, user_prompt + async def _llm_generate( self, system_prompt: str, @@ -195,17 +199,38 @@ async def _llm_generate( ) elif self.llm_provider == "ollama": return await self._call_ollama(messages, max_tokens) - elif self.llm_provider == "api": return await self._call_openai_compatible( f"{self.llm_url}/v1/chat/completions", - messages, - max_tokens, - self.llm_api_key, + messages, max_tokens, self.llm_api_key, ) else: raise ValueError(f"Unknown llm_provider: {self.llm_provider}") + + async def _llm_generate_stream( + self, system_prompt: str, user_prompt: str, preferences: dict + ): + """Yields tokens as they arrive.""" + max_tokens = preferences.get("max_length", 1024) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + if self.ollama_handle: + async for token in self.ollama_handle.generate_stream.options( + stream=True + ).remote(messages, max_tokens): + yield token + elif self.llm_provider == "ollama": + async for token in self._call_ollama_stream(messages, max_tokens): + yield token + else: + result = await self._llm_generate(system_prompt, user_prompt, preferences) + yield result + + async def _call_openai_compatible( self, url: str, messages: list, max_tokens: int, api_key: str = "" ) -> Any: @@ -226,7 +251,8 @@ async def _call_openai_compatible( resp.raise_for_status() return resp.json()["choices"][0]["message"]["content"] - async def _call_ollama(self, messages: list, max_tokens: int) -> Any: + + async def _call_ollama(self, messages: list, max_tokens: int) -> str: async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.post( f"{self.llm_url}/api/chat", @@ -240,30 +266,41 @@ async def _call_ollama(self, messages: list, max_tokens: int) -> Any: resp.raise_for_status() return resp.json()["message"]["content"] - async def process(self, query: RAGQuery) -> RAGResult: - """ - Main entry point. Called by the RAG module. - Uses _user_id to determine which collection / filters to use. - """ - print(f"[RAG] Question: {query.question}") + async def _call_ollama_stream(self, messages: list, max_tokens: int): + async with httpx.AsyncClient(timeout=120.0) as client: + async with client.stream( + "POST", + f"{self.llm_url}/api/chat", + json={ + "model": self.llm_model, + "messages": messages, + "stream": True, + "options": {"num_predict": max_tokens, "temperature": 0.1}, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + chunk = json.loads(line) + token = chunk.get("message", {}).get("content", "") + if token: + yield token + if chunk.get("done", False): + return + + async def process(self, query: RAGQuery) -> RAGResult: qdrant = await self._get_qdrant() - collection, filters = self._resolve_user_context(query._user_id) query_vector = self._embed(query.question) chunks = self._search(qdrant, query_vector, collection, filters) - print(f"[RAG] Found {len(chunks)} chunks") - for c in chunks: - print(f" - score: {c['score']:.2f} | {c['text'][:100]}...") - system_prompt, user_prompt = self._build_prompt( - query.question, chunks, query.preferences + query.question, chunks, query.preferences, query.history ) - print(f"[RAG] System prompt: {system_prompt[:200]}...") answer = await self._llm_generate(system_prompt, user_prompt, query.preferences) - print(f"[RAG] Answer: {answer}") return RAGResult( answer=answer, @@ -274,6 +311,21 @@ async def process(self, query: RAGQuery) -> RAGResult: ) + async def process_stream(self, query: RAGQuery): + qdrant = await self._get_qdrant() + collection, filters = self._resolve_user_context(query._user_id) + query_vector = self._embed(query.question) + chunks = self._search(qdrant, query_vector, collection, filters) + + system_prompt, user_prompt = self._build_prompt( + query.question, chunks, query.preferences, query.history + ) + async for token in self._llm_generate_stream( + system_prompt, user_prompt, query.preferences + ): + yield token + + class RAG(ModuleWithHandle, ModuleWithId): _handle_cls = RAGHandle input_type = "question" @@ -288,6 +340,8 @@ def __init__( response_format="paragraph", max_length=1024, extra_instructions="", + max_history=10, + stream=True, **kwargs, ): super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) @@ -299,23 +353,47 @@ def __init__( "max_length": max_length, "extra_instructions": extra_instructions, } + self.stream = stream + self.history: list[dict] = [] + self.max_history = max_history + + async def process(self, data: Sentence): + query = self._build_query(data.text) + if self._handle is None: + return + + if True: # TODO: to change later self.stream + async for token in self._stream_answer(data.text, query): + yield token + else: + result = await self._handle.process.remote(query) + self.history.append({"role": "user", "content": data.text}) + self.history.append({"role": "assistant", "content": result.answer}) + yield result + + + async def _stream_answer(self, question_text, query): + full_answer = [] + async for token in self._handle.process_stream.options( + stream=True + ).remote(query): + full_answer.append(token) + yield token + + print(f"[RAG] Full answer for question '{question_text}': {''.join(full_answer)}") - async def process(self, data: Sentence) -> Optional[RAGResult]: - """ - Called when a "question" event arrives through the event bus. - Packages _user_id + question, sends to the stateless RAGHandle. - """ - question_text = data.text + self.history.append({"role": "user", "content": question_text}) + self.history.append({"role": "assistant", "content": "".join(full_answer)}) - query = RAGQuery( - _user_id=self._user_id if self._user_id else "anonymous", + + def _build_query(self, question_text: str) -> RAGQuery: + return RAGQuery( + _user_id=self._user_id or "anonymous", question=question_text, preferences=self.preferences, + history=( + self.history + if len(self.history) <= self.max_history + else self.history[-self.max_history:] + ), ) - - result: RAGResult = await self._handle.process.remote(query) - return result - - def update_preferences(self, new_preferences: dict): - """Client can update preferences mid-session via the event bus.""" - self.preferences.update(new_preferences) diff --git a/src/modules/rag/docker_services.py b/src/services/docker_services.py similarity index 77% rename from src/modules/rag/docker_services.py rename to src/services/docker_services.py index cc9f4b5..dcf25ea 100644 --- a/src/modules/rag/docker_services.py +++ b/src/services/docker_services.py @@ -1,6 +1,7 @@ import socket import subprocess import time +import json from typing import Any import httpx @@ -119,17 +120,68 @@ def __init__( print(f"[OllamaService] Ready! \ container='{self.container_name}', port={self.port}, model='{model}'") + async def generate( self, messages: list, max_tokens: int = 1024, temperature: float = 0.1, - ) -> Any: + stream: bool = False, + ): """ - Send messages to Ollama and return the response. - This is what RAGHandle calls to get LLM answers. + Dispatcher: + stream=True -> delegates to generate_stream (yields tokens) + stream=False -> delegates to generate_paragraphs (yields paragraphs) """ - async with httpx.AsyncClient(timeout=60.0) as client: + if stream: + async for token in self.generate_stream(messages, max_tokens, temperature): + yield token + else: + async for paragraph in self.generate_paragraphs(messages, max_tokens, temperature): + yield paragraph + + + async def generate_stream( + self, + messages: list, + max_tokens: int = 1024, + temperature: float = 0.1, + ): + """Stream=True. Yields individual tokens.""" + async with httpx.AsyncClient(timeout=120.0) as client: + async with client.stream( + "POST", + f"{self.base_url}/api/chat", + json={ + "model": self.model, + "messages": messages, + "stream": True, + "options": { + "num_predict": max_tokens, + "temperature": temperature, + }, + }, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + chunk = json.loads(line) + token = chunk.get("message", {}).get("content", "") + if token: + yield token + if chunk.get("done", False): + return + + + async def generate_paragraphs( + self, + messages: list, + max_tokens: int = 1024, + temperature: float = 0.1, + ): + """Stream=False. Gets full response, then yields paragraphs.""" + async with httpx.AsyncClient(timeout=120.0) as client: resp = await client.post( f"{self.base_url}/api/chat", json={ @@ -143,7 +195,13 @@ async def generate( }, ) resp.raise_for_status() - return resp.json()["message"]["content"] + full_text = resp.json()["message"]["content"] + + for paragraph in full_text.split("\n\n"): + paragraph = paragraph.strip() + if paragraph: + yield paragraph + async def health(self) -> dict: """Check if this Ollama instance is alive.""" diff --git a/src/services/ollama.py b/src/services/ollama.py new file mode 100644 index 0000000..3278ae6 --- /dev/null +++ b/src/services/ollama.py @@ -0,0 +1,12 @@ +from typing import Any + +from src.services.docker_services import OllamaService + +def build_ollama(config: dict) -> Any: + return OllamaService.options( # type: ignore[attr-defined] + num_replicas=config.get("num_replicas", 1), + ).bind( + model=config.get("model", "mistral:7b"), + image=config.get("image", "ollama/ollama:latest"), + gpu_devices=config.get("gpu_devices", False), + ) diff --git a/src/services/qdrant.py b/src/services/qdrant.py new file mode 100644 index 0000000..3103760 --- /dev/null +++ b/src/services/qdrant.py @@ -0,0 +1,11 @@ +from typing import Any + +from src.services.docker_services import QdrantService + +def build_qdrant(config: dict) -> Any: + return QdrantService.bind( # type: ignore[attr-defined] + port=config.get("port", 6333), + image=config.get("image", "qdrant/qdrant:latest"), + storage_volume=config.get("storage_volume", "qdrant_data"), + ) +