Skip to content
Open
16 changes: 10 additions & 6 deletions config/client_aux.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
huri_url: ws://localhost:8000/session

topic_list: ["transcript", "question"]
topic_list: [question]

sample_rate: 16000
frame_duration: 0.030
senders:
audio:
name: audio
args:
sample_rate: 16000
frame_duration: 0.030

modules:
mic:
name: mic
args:
vad_agressiveness: 3
silence_duration: 1.5
block_duration: ${frame_duration}
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
stt:
name: stt
args:
language: "fr"
block_duration: ${frame_duration}
language: "en"
block_duration: ${inputs.audio.args.frame_duration}
logging: INFO
tag:
name: tag
Expand Down
22 changes: 14 additions & 8 deletions config/client_aux2.yaml
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
huri_url: ws://localhost:8000/session

topic_list: ["transcript", "question", "rag_response"]
sample_rate: 16000
frame_duration: 0.030
topic_list: [transcript, question, rag_response]

senders:
audio:
name: audio
args:
sample_rate: 16000
frame_duration: 0.030

modules:
mic:
name: mic
args:
vad_agressiveness: 3
silence_duration: 1.5
block_duration: ${frame_duration}
block_duration: ${senders.audio.args.frame_duration}
stt:
name: stt
args:
language: "en"
block_duration: ${frame_duration}
language: en
block_duration: ${senders.audio.args.frame_duration}
logging: INFO
tag:
name: tag
logging: INFO
rag:
name: rag
args:
language: "en"
tone: "formal"
language: en
tone: formal
25 changes: 25 additions & 0 deletions config/client_auxio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
huri_url: ws://localhost:8000/session

topic_list: [question]

senders:
text:
name: text

modules:
mic:
name: mic
args:
vad_agressiveness: 3
silence_duration: 1.5
block_duration: ${senders.audio.args.frame_duration}
logging: INFO
stt:
name: stt
args:
language: en
block_duration: ${senders.audio.args.frame_duration}
logging: INFO
tag:
name: tag
logging: INFO
16 changes: 13 additions & 3 deletions config/client_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
huri_url: ws://localhost:8000/session

# List of event topic the client will receive
topic_list: ["topic1", "topic2"]
topic_list: [topic1, topic2]

# Define module custom args
# Define senders to be used and their custom args
senders:
# sender tag can be anything
example:
# sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders)
name: my_sender
# if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here
args:
refresh_rate: infinite

# Define module to be used and their custom args
modules:
# module tag can be anything
example:
# module name must be in the list of available module in HuRI's instance (src.modules.modules:get_modules)
name: my_module
# if my_module init with "model", "sample_rate" and "hello" params, they can be customized here
args:
hello: "world"
hello: world
15 changes: 15 additions & 0 deletions config/client_text.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
huri_url: ws://localhost:8000/session

topic_list: [question, rag_response]

senders:
text:
name: text

modules:
rag:
name: rag
args:
language: en
tone: formal
logging: INFO
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ webrtcvad
faster-whisper
qdrant-client
sentence-transformers
pypdf
semantic_chunker


# client
sounddevice
websockets
omegaconf

prompt-toolkit
5 changes: 4 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ray.serve import Application

from src.core.huri import HuRI
from src.modules.events import get_events
from src.modules.factory import bind_deployment_handles
from src.modules.modules import get_modules
from src.modules.rag.docker_services import OllamaService, QdrantService
Expand Down Expand Up @@ -37,13 +38,15 @@ def build_ollama(config: dict) -> Any:

def build_app() -> Application:
modules = get_modules()
events = get_events()

services_config = load_services_config()

qdrant = build_qdrant(services_config.get("qdrant", {}))
ollama = build_ollama(services_config.get("ollama", {}))

handles = bind_deployment_handles(modules, ollama=ollama, qdrant=qdrant)
app: Application = HuRI.bind(modules, handles) # type: ignore[attr-defined]
app: Application = HuRI.bind(modules, handles, events) # type: ignore[attr-defined]
return app


Expand Down
72 changes: 7 additions & 65 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,12 @@
import argparse
import asyncio
import json
import os
from dataclasses import asdict
from typing import Dict

import numpy as np
import sounddevice as sd
import websockets
from omegaconf import OmegaConf

from src.core.client import Client
from src.core.dataclasses.config import ClientConfig

USER_ID_FILE = os.path.expanduser("~/.huri_user_id")


def load_user_id() -> str | None:
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
return None


def save_user_id(_user_id: str):
with open(USER_ID_FILE, "w") as f:
f.write(_user_id)


def load_client_config(path: str) -> ClientConfig:
with open(path) as f:
Expand All @@ -38,7 +19,7 @@ def load_client_config(path: str) -> ClientConfig:
return ClientConfig.from_dict(raw_resolved)


async def stream_audio():
async def launch_client():
parser = argparse.ArgumentParser(description="Client config")
parser.add_argument(
"--config",
Expand All @@ -49,50 +30,11 @@ async def stream_audio():
args = parser.parse_args()
config = load_client_config(args.config)

FRAME_SIZE = int(config.sample_rate * config.frame_duration)
async with websockets.connect(config.huri_url) as ws:
print("Connected to server")

payload = asdict(config)
_user_id = load_user_id()
if _user_id:
payload["_user_id"] = _user_id
print(f"Reconnecting with _user_id: {_user_id}")

await ws.send(json.dumps(payload))

init_msg = json.loads(await ws.recv())
if init_msg.get("type") == "session_init":
_user_id = init_msg["_user_id"]
save_user_id(_user_id)
print(f"Session started with _user_id: {_user_id}")

async def receive(ws: websockets.ClientConnection):
while True:
text = await ws.recv()
print("received:", text)

async def send(ws: websockets.ClientConnection):
loop = asyncio.get_running_loop()

queue: asyncio.Queue = asyncio.Queue()

def callback(indata: np.ndarray, frames, time, status):
loop.call_soon_threadsafe(queue.put_nowait, indata.copy())

with sd.InputStream(
samplerate=config.sample_rate,
channels=1,
dtype="int16",
callback=callback,
blocksize=FRAME_SIZE,
):
while True:
chunk = await queue.get()
await ws.send(chunk.tobytes())

await asyncio.gather(receive(ws), send(ws))
await Client(config=config).run()


if __name__ == "__main__":
asyncio.run(stream_audio())
try:
asyncio.run(launch_client())
except KeyboardInterrupt:
pass
71 changes: 71 additions & 0 deletions src/core/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
import json
import os
from dataclasses import asdict
from typing import Dict, List, Optional, Type

import websockets

from src.core.dataclasses.config import ClientConfig

from .client_senders import ClientSender, get_senders


class Client:
"""Client is init with a Config, and connects to HuRI using websockets"""

def __init__(
self,
config: ClientConfig,
user_id_file: str = os.path.expanduser("~/.huri_user_id"),
senders_dict: Dict[str, Type[ClientSender]] = get_senders(),
):
self.config = config
self.user_id_file = user_id_file
self.senders_dict = senders_dict

def _load_user_id(self) -> Optional[str]:
if os.path.exists(self.user_id_file):
with open(self.user_id_file) as f:
return f.read().strip()
return None

def _save_user_id(self, _user_id: str):
with open(self.user_id_file, "w") as f:
f.write(_user_id)

async def _receive_loop(self, ws: websockets.ClientConnection):
try:
while True:
text = await ws.recv()
print("<<", text)
await asyncio.sleep(0.1)

except (asyncio.CancelledError, websockets.ConnectionClosedOK):
pass

async def run(self):
async with websockets.connect(self.config.huri_url) as ws:
print("Connected to server")

self.config.user_id = self._load_user_id()

senders: List[ClientSender] = [
self.senders_dict[config.name](ws=ws, **config.args)
for config in self.config.senders.values()
]

await ws.send(json.dumps(asdict(self.config)))

init_msg = json.loads(await ws.recv())
if init_msg.get("type") == "session_init":
user_id = init_msg["user_id"]
self._save_user_id(user_id)
print(f"Session started with _user_id: {user_id}")

receive_task = asyncio.create_task(self._receive_loop(ws))
await asyncio.gather(
*(sender.input_loop() for sender in senders),
)

receive_task.cancel()
Loading