Skip to content

Commit

Permalink
feat: Add WebID attestation for requests to llm_service
Browse files Browse the repository at this point in the history
This version uses Header parameter for retrieving header.
But `Optional` does not seem to be respected correctly.
  • Loading branch information
renyuneyun committed Feb 4, 2024
1 parent 6b3c2e7 commit 300f379
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

node_modules/
11 changes: 10 additions & 1 deletion chat_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain.schema import messages_to_dict
from chat_app.solid_message_history import SolidChatMessageHistory
from solid_oidc_client import SolidAuthSession

hostname = os.environ.get("WEBSITE_HOSTNAME")
if hostname is not None:
Expand Down Expand Up @@ -96,6 +97,11 @@ def print_state_messages(history: BaseChatMessageHistory):
st.markdown(message.content)


def get_auth_headers(st, url, method):
solid_auth = SolidAuthSession.deserialize(st.session_state["solid_token"])
return solid_auth.get_auth_headers(url, method)


def main():
st.set_page_config(page_title="Social Gen Pod", page_icon="🐢")
show_pages([
Expand Down Expand Up @@ -134,12 +140,15 @@ def main():
history.add_user_message(prompt)

with st.spinner("LLM is thinking...."):
url = "http://localhost:5000/completions/"
auth_headers = get_auth_headers(st, url, 'POST')
response = requests.post(
"http://localhost:5000/completions/",
url,
json={
"model": selected_llm,
"messages": messages_to_dict(history.messages),
},
headers=auth_headers,
)
st.session_state["input_disabled"] = False
if not response.ok:
Expand Down
55 changes: 55 additions & 0 deletions llm_service/attest.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
const createSolidTokenVerifier =
require("@solid/access-token-verifier").createSolidTokenVerifier;

/**
* Check whether the request belongs to a / the corresponding WebID.
* @param {string} authorizationHeader The `authorization` header
* @param {string} dpopHeader The `DPoP` header
* @param {string} requestMethod The HTTP method for the request
* @param {string} requestURL The URL of the request
* @param {string|undefined} claimedWebid What WebID the client claims to be (can be `undefined`)
* @returns {boolean|string} If `claimedWebid` is not empty, return whether the claimed WebID matches the real WebID in the credentials; otherwise, return the real WebID.
*/
async function attestWebidPossession(
authorizationHeader,
dpopHeader,
requestMethod,
requestURL,
claimedWebid
) {
const solidOidcAccessTokenVerifier = createSolidTokenVerifier();

try {
const { client_id: clientId, webid: webId } =
await solidOidcAccessTokenVerifier(authorizationHeader, {
header: dpopHeader,
method: requestMethod,
url: requestURL,
});

if (!claimedWebid) {
return webId;
}

return webId == claimedWebid;
} catch (error) {
const message = `Error verifying Access Token via WebID: ${error.message}`;
throw new Error(message);
}
}

// module.exports = {
// attestWebidPossession,
// };

async function main() {
const res = await attestWebidPossession(...process.argv.slice(2));

if (res) {
process.exit(0);
} else {
process.exit(1);
}
}

main();
43 changes: 41 additions & 2 deletions llm_service/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi import FastAPI
from fastapi import FastAPI, Depends, Header, Request, HTTPException
from pydantic import BaseModel
import uvicorn
from typing import Optional
from langchain.schema import messages_from_dict

from .chains import make_conversation_chain
from .config import get_config
from .solid_utils import attessPossession

app = FastAPI()
config = get_config()
Expand All @@ -15,6 +17,40 @@ class ChatCompletionRequestData(BaseModel):
messages: list[dict]


def as_header(cls):
"""decorator for pydantic model
replaces the Signature of the parameters of the pydantic model with `Header`
See https://github.com/tiangolo/fastapi/issues/2915
"""
cls.__signature__ = cls.__signature__.replace(
parameters=[
arg.replace(
default=Header(...) if arg.default is arg.empty else Header(arg.default)
)
for arg in cls.__signature__.parameters.values()
]
)
return cls


@as_header
class WebIdDPoPInfoHeader(BaseModel):
authorization: str
dpop: str
x_forwarded_host: Optional[str]
x_forwarded_protocol: Optional[str]
webid: Optional[str]


def checkIdentity(request: Request, hdrs: WebIdDPoPInfoHeader):
method = 'POST'
host = hdrs.x_forwarded_host or request.url.hostname # Use X-Forwarded-For in case there is a reverse proxy in-between the client and the server
protocol = hdrs.x_forwarded_protocol or request.url.scheme # Same as above
path_prefix = '/' # Needed if deployed to a (sub)path instead of root of the hostname
request_url = f"{protocol}://{host}{path_prefix}{request.url.path}"
return attessPossession(hdrs.authorization, hdrs.dpop, method, request_url, hdrs.webid)


@app.get("/")
def read_root():
return {"Hello": "World"}
Expand Down Expand Up @@ -62,7 +98,10 @@ def create_embeddings():


@app.post("/completions/")
def chat_completion(req: ChatCompletionRequestData) -> str:
def chat_completion(req: ChatCompletionRequestData, request: Request, hdrs: WebIdDPoPInfoHeader = Depends(WebIdDPoPInfoHeader)) -> str:
if not checkIdentity(request):
raise HTTPException(400, detail='WebID attestation failed')

selected_model_idx = -1
for idx, llm in enumerate(config["llms"]):
if llm["model"] == req.model:
Expand Down
Loading

0 comments on commit 300f379

Please sign in to comment.