Skip to content

Commit

Permalink
Database client name fix, and more rules for malformed transactions (#17
Browse files Browse the repository at this point in the history
)

- Fix the name of a database client (the linter was failing silently on
TypeORM before)
- Add some TODOs
- Add some rules for checking if a transaction is malformed
- Now not running SQL injection over every method, only over
transactions
  • Loading branch information
CaspianA1 authored Aug 12, 2024
1 parent 124a71d commit c8f1947
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 42 deletions.
41 changes: 33 additions & 8 deletions dbos-rules.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ function makeSqlInjectionCode(code: string, sqlClient: string): string {
$executeRawUnsafe(query: string, ...values: any[]) {}
}
class PoolClient {
query(query: string, ...values: any[]) {}
class EntityManager {
query<T extends unknown[]>(query: string, parameters?: T) {}
}
class TypeORMEntityManager {
query<T extends unknown[]>(query: string, parameters?: T) {}
class PoolClient {
query(query: string, ...values: any[]) {}
}
function Transaction(target?: any, key?: any, descriptor?: any): any {
Expand Down Expand Up @@ -348,15 +348,15 @@ const testSet: TestSet = [
"PoolClient"
),

// Failure test #8 (testing `TypeORMEntityManager`)
// Failure test #8 (testing `EntityManager`)
makeSqlInjectionFailureTest(`
ctxt.client.query("foo" + (5).toString());
`,
Array(1).fill("sqlInjection"),
"TypeORMEntityManager"
"EntityManager"
),

// Failure test #9 (testing not using `TransactionContext`)
// Failure test #9 (testing not using `TransactionContext`, and malformed transactions)
makeSqlInjectionFailureTest(`
ctxt;
Expand Down Expand Up @@ -387,8 +387,33 @@ const testSet: TestSet = [
// But this one does
ctxt.client.raw(bob.baz2);
//////////
// Testing transactions without params
class Other2 {
@Transaction()
myInvalidTransactionWithoutParams() {}
}
// Testing transactions without a specified client type
class Other3 {
@Transaction()
myInvalidTransactionWithoutTypeParam(ctxt: TransactionContext) {}
}
class InvalidDatabaseClient {}
// Testing transactions with an invalid client type
class Other4 {
@Transaction()
myInvalidTransactionWithoutTypeParam(ctxt: TransactionContext<InvalidDatabaseClient>) {}
}
`,
Array(1).fill("transactionDoesntUseTheDatabase")
[
"transactionDoesntUseTheDatabase", "transactionHasNoParameters",
"transactionContextHasNoTypeArguments", "transactionContextHasInvalidClientType"
]
),

// Failure test #10 (a simpler object test)
Expand Down
100 changes: 69 additions & 31 deletions dbos-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ const awaitableTypes = new Set(["WorkflowContext"]); // Awaitable in determinist

// This maps the ORM client name to a list of raw SQL query calls to check
const ormClientInfoForRawSqlQueries: Map<string, string[]> = new Map([
["PoolClient", ["query"]],
["PrismaClient", ["$queryRawUnsafe", "$executeRawUnsafe"]],
["TypeORMEntityManager", ["query"]],
["Knex", ["raw"]]
["Knex", ["raw"]], // For Knex
["PrismaClient", ["$queryRawUnsafe", "$executeRawUnsafe"]], // For Prisma
["EntityManager", ["query"]], // For TypeORM
["PoolClient", ["query"]], // This is supported in `dbos-transact` (see `user_database.ts`, but not sure what ORM this corresponds to)
// ["PgDatabase", []], // For Drizzle (TODO: add full support for this)
]);

const assignmentTokenKinds = new Set([
Expand Down Expand Up @@ -107,8 +108,12 @@ and accesses via brackets (e.g. \`a["b"]\`) only succeed when every field in the

// The keys are the ids, and the values are the messages themselves
return new Map([
["sqlInjection", `Possible SQL injection detected. The parameter to the query call site traces back to the nonliteral on line {{ lineNumber }}: '{{ theExpression }}'\n${sqlInjectionNotes}`],
["transactionHasNoParameters", "This transaction has no parameters; add a `TransactionContext` parameter"],
["transactionContextHasNoTypeArguments", "The context passed to this transaction has no type arguments; add one to specify the database client"],
["transactionContextHasInvalidClientType", "The database client type `{{ clientType }}` used here is not recognized by the linter; consult the DBOS docs to find a supported one"],
["transactionDoesntUseTheDatabase", "This transaction does not use the database (via its `client` field). Consider using a communicator or a normal function"],

["sqlInjection", `Possible SQL injection detected. The parameter to the query call site traces back to the nonliteral on line {{ lineNumber }}: \`{{ theExpression }}\`\n${sqlInjectionNotes}`],
["globalMutation", "Deterministic DBOS operations (e.g. workflow code) should not mutate global variables; it can lead to non-reproducible behavior"],
["awaitingOnNotAllowedType", awaitMessage],
["Date", makeDateMessage("`Date()` or `new Date()`")],
Expand Down Expand Up @@ -156,10 +161,13 @@ TODO (requests from others, and general things for me to do):
- Chuck gave a suggestion to allow some function calls for LR-values; and do this by finding a way to mark them as constant
- Alex gave me this suggestion from a user: resolve many promises in parallel (with `Promise.all`)
From me:
- Run this over `dbos-transact`
- Maybe track type and variable aliasing somewhere, somehow (if needed)
- Mark some simple function calls as being constant (this could quickly spiral in terms of complexity)
- Add full Drizzle support!
*/

////////// These are some utility functions
Expand Down Expand Up @@ -187,6 +195,18 @@ function getSymbol(nodeOrType: Node | Type): Maybe<Symbol> {
return symbol;
}

function getTypeName(nodeOrType: Node | Type): string {
if (nodeOrType instanceof Node) {
// If it's a literal type, it'll get the base type; otherwise, nothing happens
const type = nodeOrType.getType().getBaseTypeOfLiteralType();
const maybeSymbol = getSymbol(type);
return maybeSymbol?.getName() ?? type.getText(nodeOrType);
}
else {
return getSymbol(nodeOrType)?.getName() ?? nodeOrType.getText();
}
}

function getRefsToNodeOrSymbol(nodeOrSymbol: Node | Symbol): Node[] {
let maybeSymbol = nodeOrSymbol instanceof Node ? getSymbol(nodeOrSymbol) : nodeOrSymbol;

Expand Down Expand Up @@ -288,11 +308,10 @@ const callsBannedFunction: ErrorChecker = (node, _fnDecl, _isLocal) => {

// TODO: match against `.then` as well (with a promise object preceding it)
const awaitsOnNotAllowedType: ErrorChecker = (node, _fnDecl, _isLocal) => {

// If the valid type set and arg type set intersect, then there's a valid type in the args.
function validTypeExistsInFunctionCallParams(functionCall: CallExpression, validTypes: Set<string>): boolean {
// I'd like to use `isDisjointFrom` here, but it doesn't seem to be available, for some reason
const argTypes = functionCall.getArguments().map(getTypeNameForTsMorphNode);
const argTypes = functionCall.getArguments().map(getTypeName);
return argTypes.some((argType) => validTypes.has(argType));
}

Expand All @@ -316,8 +335,7 @@ const awaitsOnNotAllowedType: ErrorChecker = (node, _fnDecl, _isLocal) => {

//////////

const typeName = getTypeNameForTsMorphNode(lhs);
const awaitingOnAllowedType = awaitableTypes.has(typeName);
const awaitingOnAllowedType = awaitableTypes.has(getTypeName(lhs));

if (!awaitingOnAllowedType) {
/* We should be allowed to await if we call a function that passes
Expand Down Expand Up @@ -567,14 +585,15 @@ function checkCallForInjection(callParam: Node): Maybe<ErrorMessageIdWithFormatD
if (ref === node) continue;

else if (!Node.isVariableDeclaration(ref)) {
panic("Unknown structure of assignment value symbol for shorthand property assignment!");
debugLog("Unknown structure of assignment value symbol for shorthand property assignment!");
continue;
}

const initializer = ref.getInitializer();
if (initializer !== Nothing && !isLR(initializer)) return false;
}

debugLog(`No refs exist pointing to this shorthand property assignment: ${node.getText()}`);
debugLog(`No refs exist pointing to this shorthand property assignment: '${node.getText()}'`);
return true;
}
// TODO: support spread assignments
Expand Down Expand Up @@ -618,7 +637,7 @@ function checkCallForInjection(callParam: Node): Maybe<ErrorMessageIdWithFormatD
}
}

// If it's a raw SQL injection callsite, then this returns the argument to examine
// If it's a raw SQL injection callsite, then this returns the argument to examine.
function maybeGetArgFromRawSqlCallSite(callExpr: CallExpression): Maybe<Node> {
const callExprWithoutParams = callExpr.getExpression();
const args = callExpr.getArguments();
Expand All @@ -629,12 +648,19 @@ function maybeGetArgFromRawSqlCallSite(callExpr: CallExpression): Maybe<Node> {
// `client.<callName>`, or `ctxt.client.<callName>`, and so on with the prefixes
const identifiers = callExprWithoutParams.getDescendantsOfKind(SyntaxKind.Identifier);

const identifierTypeNames = identifiers.map(getTypeNameForTsMorphNode);
if (identifierTypeNames.length < 2) return; // Can't recognize a raw SQL call for 0 or 1 identifiers
if (identifiers.length <= 1) {
debugLog(`Cannot recognize a raw SQL call from this here: '${callExpr.getText()}'`);
return;
}

const identifierTypeNames = identifiers.map(getTypeName);
const expectedClient = identifierTypeNames[identifierTypeNames.length - 2];
const callNames = ormClientInfoForRawSqlQueries.get(expectedClient);
if (callNames === Nothing) return;

if (callNames === Nothing) {
debugLog(`Unrecognized database client: '${expectedClient}'`);
return;
}

const expectedRawQueryCall = identifiers[identifiers.length - 1].getText();

Expand All @@ -651,21 +677,42 @@ const isSqlInjection: ErrorChecker = (node, _fnDecl, _isLocal) => {
return checkCallForInjection(maybeArg);
}
}
};
}

////////// This code is for detecting useless transactions
////////// This code is for detecting useless/malformed transactions

/* Note: this may result in false negatives for nested closures that capture the transaction context's client,
and when you call helper functions that you pass the context object to, but that helper function does nothing. */
const transactionDoesntUseTheDatabase: ErrorChecker = (node, fnDecl, _isLocal) => {
const transactionIsMalformed: ErrorChecker = (node, fnDecl, _isLocal) => {
if (node !== fnDecl) return; // Only analyze the whole function

////////// Step 1: check if the transaction has no parameters

const params = fnDecl.getParameters();
if (params.length === 0) return; // In this case, not a valid transaction
if (params.length === 0) return "transactionHasNoParameters";

////////// Step 2: check if the transaction context has no type arguments

const transactionContext = params[0];
const transactionContextSymbol = getSymbol(transactionContext); // The first param should be the transaction context
if (transactionContextSymbol === Nothing) return; // No symbol for the first param -> should not analyze
const typeArgs = transactionContext.getType().getTypeArguments();
if (typeArgs.length === 0) return "transactionContextHasNoTypeArguments";

////////// Step 3: check if the database client used is unrecognized

const clientType = getTypeName(typeArgs[0]);

if (!ormClientInfoForRawSqlQueries.has(clientType)) {
return ["transactionContextHasInvalidClientType", {clientType: clientType}];
}

////////// Step 4: check if the transaction context is never used

const transactionContextSymbol = getSymbol(transactionContext);

if (transactionContextSymbol === Nothing) {
debugLog("No symbol was ever found for the transaction context!");
return; // No symbol for the first param -> should not analyze
}

let foundDatabaseUsage = false;

Expand Down Expand Up @@ -703,8 +750,7 @@ First field: a set of method decorators to match on (if `Nothing`, then match on
Second field: a list of error checkers to run.
*/
const decoratorSetErrorCheckerMapping: [Maybe<Set<string>>, ErrorChecker[]][] = [
[Nothing, [isSqlInjection]], // Checking for SQL injection here (all functions)
[new Set(["Transaction"]), [transactionDoesntUseTheDatabase]], // Checking for useless transactions here
[new Set(["Transaction"]), [isSqlInjection, transactionIsMalformed]], // Checking for SQL injection and malformed transactions here
[new Set(["Workflow"]), [mutatesGlobalVariable, callsBannedFunction, awaitsOnNotAllowedType]] // Checking for nondeterminism here
];

Expand Down Expand Up @@ -816,14 +862,6 @@ function makeEslintNode(tsMorphNode: Node): EslintNode {
return eslintNode;
}

function getTypeNameForTsMorphNode(tsMorphNode: Node): string {
// If it's a literal type, it'll get the base type; otherwise, nothing happens
const type = tsMorphNode.getType().getBaseTypeOfLiteralType();

const maybeSymbol = getSymbol(type);
return maybeSymbol?.getName() ?? type.getText();
}

// This is just for making sure that the unit tests are well constructed (not used when deployed)
function checkDiagnostics(node: Node) {
const project = new Project({});
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@dbos-inc/eslint-plugin",
"version": "3.2.0",
"version": "3.3.0",
"description": "eslint plugin for DBOS SDK",
"license": "MIT",
"repository": {
Expand Down

0 comments on commit c8f1947

Please sign in to comment.