Skip to content

Commit

Permalink
Implement chat threads
Browse files Browse the repository at this point in the history
  • Loading branch information
Vidminas committed Feb 3, 2024
1 parent 7ec2f2c commit dc4363c
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 56 deletions.
119 changes: 81 additions & 38 deletions chat_app/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import os
from urllib.parse import unquote

import requests
import streamlit as st
from st_pages import Page, show_pages, hide_pages
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.schema import messages_to_dict
from chat_app.solid_message_history import SolidChatMessageHistory
from chat_app.solid_pod_utils import SolidPodUtils

hostname = os.environ.get("WEBSITE_HOSTNAME")
if hostname is not None:
Expand All @@ -15,43 +18,42 @@

def setup_login_sidebar():
from chat_app.solid_oidc_button import SolidOidcComponent
from solid_oidc_client import SolidAuthSession

# Default IDP list from https://solidproject.org/users/get-a-pod
solid_server_url = st.sidebar.selectbox(
"Solid ID Provider",
(
"https://solidcommunity.net/",
"https://login.inrupt.com/",
"https://solidweb.org/",
"https://trinpod.us/",
"https://get.use.id/",
"https://solidweb.me/",
"https://datapod.igrant.io/",
"https://solid.redpencil.io/",
"https://teamid.live/",
"Other...",
),
disabled="solid_token" in st.session_state,
)
if solid_server_url == "Other...":
solid_server_url = st.sidebar.text_input(
"Solid Server URL",
"https://solidpod.azurewebsites.net",

if "solid_token" not in st.session_state:
# Default IDP list from https://solidproject.org/users/get-a-pod
solid_server_url = st.sidebar.selectbox(
"Solid ID Provider",
(
"https://solidcommunity.net/",
"https://login.inrupt.com/",
"https://solidweb.org/",
"https://trinpod.us/",
"https://get.use.id/",
"https://solidweb.me/",
"https://datapod.igrant.io/",
"https://solid.redpencil.io/",
"https://teamid.live/",
"Other...",
),
disabled="solid_token" in st.session_state,
)
if solid_server_url == "Other...":
solid_server_url = st.sidebar.text_input(
"Solid Server URL",
"https://solidpod.azurewebsites.net",
disabled="solid_token" in st.session_state,
)

if "solid_idps" not in st.session_state:
st.session_state["solid_idps"] = {}
if "solid_idps" not in st.session_state:
st.session_state["solid_idps"] = {}

if solid_server_url not in st.session_state["solid_idps"]:
st.session_state["solid_idps"][solid_server_url] = SolidOidcComponent(
solid_server_url
)
if solid_server_url not in st.session_state["solid_idps"]:
st.session_state["solid_idps"][solid_server_url] = SolidOidcComponent(
solid_server_url
)

solid_client = st.session_state["solid_idps"][solid_server_url]
solid_client = st.session_state["solid_idps"][solid_server_url]

if "solid_token" not in st.session_state:
with st.sidebar:
result = solid_client.authorize_button(
name="Login with Solid",
Expand All @@ -67,8 +69,8 @@ def setup_login_sidebar():
st.session_state["solid_token"] = result["token"]
st.rerun()
else:
solid_auth = SolidAuthSession.deserialize(st.session_state["solid_token"])
st.sidebar.markdown(f"Logged in as <{solid_auth.get_web_id()}>")
solid_utils = SolidPodUtils(st.session_state["solid_token"])
st.sidebar.markdown(f"Logged in as <{solid_utils.webid}>")

def logout():
# TODO: this should also revoke the token, but not implemented yet
Expand All @@ -78,11 +80,53 @@ def logout():

st.sidebar.button("Log Out", on_click=logout)

threads = solid_utils.list_container_items(solid_utils.workspace_uri)
if "msg_history" not in st.session_state:
st.session_state["msg_history"] = SolidChatMessageHistory(
st.session_state["solid_token"],
thread_uri=threads[0] if len(threads) else None,
)

def switch_active_thread(new_thread_uri):
if new_thread_uri != st.session_state["msg_history"].thread_uri:
st.session_state["msg_history"] = SolidChatMessageHistory(
st.session_state["solid_token"], new_thread_uri
)

st.sidebar.divider()
st.sidebar.caption("Chats")

def init_messages(history: BaseChatMessageHistory) -> None:
clear_button = st.sidebar.button("Clear Conversation", key="clear")
if clear_button or len(history.messages) == 0:
history.clear()
for thread in threads:
thread_label = unquote(
thread.removeprefix(solid_utils.workspace_uri).removesuffix(".ttl")
)
with st.sidebar:
col1, col2 = st.columns([5, 1])
col1.button(
label=thread_label,
key=thread,
on_click=switch_active_thread,
args=(thread,),
type="primary"
if thread == st.session_state["msg_history"].thread_uri
else "secondary",
use_container_width=True,
)
col2.button(
label=":wastebasket:",
key="del_" + thread,
help="Delete " + thread_label,
on_click=st.session_state["msg_history"].clear,
)
if not len(threads):
st.sidebar.write("Nothing here yet... Start typing on the right ->")
st.sidebar.button(
label="Start new conversation",
on_click=switch_active_thread,
args=(None,),
use_container_width=True,
)
st.sidebar.divider()


def print_state_messages(history: BaseChatMessageHistory):
Expand Down Expand Up @@ -115,7 +159,6 @@ def main():
st.session_state["solid_token"]
)
history = st.session_state["msg_history"]
init_messages(history)
print_state_messages(history)

if "llm_options" not in st.session_state:
Expand Down
32 changes: 24 additions & 8 deletions chat_app/solid_message_history.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import quote

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from rdflib import Graph, BNode, URIRef, Literal, RDF, PROF, XSD
Expand All @@ -14,18 +16,21 @@ class SolidChatMessageHistory(BaseChatMessageHistory):
solid_token: A serialized SolidAuthSession
"""

def __init__(self, solid_token):
def __init__(self, solid_token, thread_uri=None):
self.graph = Graph()
self.solid_utils = SolidPodUtils(solid_token)
self.genpod_messages_uri = self.solid_utils.workspace_uri + "genpod.ttl"
self.thread_uri = thread_uri

@property
def messages(self) -> list[BaseMessage]:
"""Retrieve the current list of messages"""
if not self.solid_utils.is_solid_item_available(self.genpod_messages_uri):
self.solid_utils.create_solid_item(self.genpod_messages_uri)
if self.thread_uri is None:
return []

if not self.solid_utils.is_solid_item_available(self.thread_uri):
self.solid_utils.create_solid_item(self.thread_uri)

self.graph = self.solid_utils.read_solid_item(self.genpod_messages_uri)
self.graph = self.solid_utils.read_solid_item(self.thread_uri)
list_node = self.graph.value(predicate=RDF.type, object=RDF.List)
if list_node is None:
return []
Expand All @@ -44,6 +49,16 @@ def messages(self) -> list[BaseMessage]:

def add_message(self, message: BaseMessage) -> None:
"""Add a message to the session memory"""
if self.thread_uri is None:
thread_name = quote(" ".join(message.content.split(maxsplit=3)[:3]), safe="")
candidate_uri = self.solid_utils.workspace_uri + thread_name + ".ttl"
i = 2
while self.solid_utils.is_solid_item_available(candidate_uri):
candidate_uri = self.solid_utils.workspace_uri + thread_name + f" #{i}.ttl"
i += 1
self.thread_uri = candidate_uri
self.solid_utils.create_solid_item(self.thread_uri)

# https://solidproject.org/TR/protocol#n3-patch seems to be broken with Community Solid Server
# https://www.w3.org/TR/sparql11-update/ works
update_graph = Graph()
Expand All @@ -59,7 +74,7 @@ def add_message(self, message: BaseMessage) -> None:

list_node = self.graph.value(predicate=RDF.type, object=RDF.List)
if list_node is None:
msgs_node = URIRef(f"{self.genpod_messages_uri}#messages")
msgs_node = URIRef(f"{self.thread_uri}#messages")
update_graph.add((msgs_node, RDF.type, RDF.List))

msgs = Collection(update_graph, msgs_node)
Expand Down Expand Up @@ -92,11 +107,12 @@ def add_message(self, message: BaseMessage) -> None:
"""

# Update remote copy
self.solid_utils.update_solid_item(self.genpod_messages_uri, sparql)
self.solid_utils.update_solid_item(self.thread_uri, sparql)
# Update local copy
self.graph.update(sparql)

def clear(self) -> None:
"""Clear session memory"""
self.solid_utils.delete_solid_item(self.genpod_messages_uri)
self.solid_utils.delete_solid_item(self.thread_uri)
self.thread_uri = None
self.graph = Graph()
42 changes: 32 additions & 10 deletions chat_app/solid_pod_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
solid_ns = Namespace("http://www.w3.org/ns/solid/terms#")
ldp_ns = Namespace("http://www.w3.org/ns/ldp#")

CONFIG_RESOURCE_NAME = "_config.ttl"


def get_item_name(url: str) -> str:
if url[-1] == "/":
Expand All @@ -29,7 +31,7 @@ class SolidPodUtils:
solid_token: A serialized SolidAuthSession
"""

def __init__(self, solid_token):
def __init__(self, solid_token: str):
self.solid_auth = SolidAuthSession.deserialize(solid_token)
self.session = requests.Session()

Expand Down Expand Up @@ -101,10 +103,15 @@ def __init__(self, solid_token):
f"}}"
)
self.update_solid_item(private_index_uri, sparql)

if not self.is_solid_item_available(self.workspace_uri):
self.create_solid_item(self.workspace_uri)

def is_solid_item_available(self, url) -> bool:
self.config_uri = self.workspace_uri + CONFIG_RESOURCE_NAME
if not self.is_solid_item_available(self.config_uri):
self.create_solid_item(self.config_uri)

def is_solid_item_available(self, url: str) -> bool:
try:
res = self.session.head(
url,
Expand All @@ -115,7 +122,7 @@ def is_solid_item_available(self, url) -> bool:
except requests.exceptions.ConnectionError:
return False

def create_solid_item(self, uri: str) -> bool:
def create_solid_item(self, uri: str) -> None:
res = self.session.put(
uri,
data=None,
Expand All @@ -130,9 +137,10 @@ def create_solid_item(self, uri: str) -> bool:
**self.solid_auth.get_auth_headers(uri, "PUT"),
},
)
return res.ok

def read_solid_item(self, uri) -> Graph:
if not res.ok:
raise RuntimeError("Error creating item " + uri + ": " + res.text)

def read_solid_item(self, uri: str) -> Graph:
content = Graph()
content.bind("solid", solid_ns)
content.bind("pim", pim_ns)
Expand All @@ -144,10 +152,22 @@ def read_solid_item(self, uri) -> Graph:
**self.solid_auth.get_auth_headers(uri, "GET"),
},
)
if not res.ok:
raise RuntimeError("Error reading item " + uri + ": " + res.text)
content.parse(data=res.text, publicID=uri)
return content

def update_solid_item(self, uri: str, sparql: str):
def list_container_items(
self, uri: str, ignore_resource_names=[CONFIG_RESOURCE_NAME]
) -> list[URIRef]:
container_graph = self.read_solid_item(uri)
return [
item
for item in container_graph.objects(URIRef(uri), ldp_ns.contains)
if item.removeprefix(uri) not in ignore_resource_names
]

def update_solid_item(self, uri: str, sparql: str) -> None:
res = self.session.patch(
url=uri,
data=sparql.encode("utf-8"),
Expand All @@ -156,11 +176,13 @@ def update_solid_item(self, uri: str, sparql: str):
**self.solid_auth.get_auth_headers(uri, "PATCH"),
},
)
return res.ok
if not res.ok:
raise RuntimeError("Error updating item " + uri + ": " + res.text)

def delete_solid_item(self, uri: str):
def delete_solid_item(self, uri: str) -> None:
res = self.session.delete(
uri,
headers=self.solid_auth.get_auth_headers(uri, "DELETE"),
)
return res.ok
if not res.ok:
raise RuntimeError("Error deleting item " + uri + ": " + res.text)

0 comments on commit dc4363c

Please sign in to comment.