Skip to content

Commit

Permalink
Migrating Raw SQL Statements to Prepared Statements (#87) (#88)
Browse files Browse the repository at this point in the history
* #OBS-I148: Migration of SQL queries to prepared statement to avoid the SQL injection

* #OBS-I148: Migration of SQL queries to prepared statement to avoid the SQL injection

* #OBS-I148: Removed the unwanted imports

* #OBS-I148: System Config Changes - Converted from raw query to prepared statements

Co-authored-by: Ravi Mula <ravismula@users.noreply.github.com>
  • Loading branch information
manjudr and ravismula authored Aug 7, 2024
1 parent 3aafbd2 commit 65074e3
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import org.sunbird.obsrv.model.DatasetModels._
import org.sunbird.obsrv.model.{DatasetStatus, TransformMode}

import java.io.File
import java.sql.{ResultSet, Timestamp}
import java.sql.{PreparedStatement, ResultSet, Timestamp}

object DatasetRegistryService {
private val configFile = new File("/data/flink/conf/baseconfig.conf")
Expand Down Expand Up @@ -42,22 +42,28 @@ object DatasetRegistryService {
}

def readDataset(id: String): Option[Dataset] = {

val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
var resultSet: ResultSet = null
try {
val rs = postgresConnect.executeQuery(s"SELECT * FROM datasets where id='$id'")
if (rs.next()) {
Some(parseDataset(rs))
val query = "SELECT * FROM datasets WHERE id = ?"
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, id)
resultSet = postgresConnect.executeQuery(preparedStatement = preparedStatement)
if (resultSet.next()) {
Some(parseDataset(resultSet))
} else {
None
}
} finally {
if (resultSet != null) resultSet.close()
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}

def readAllDatasetSourceConfig(): Option[List[DatasetSourceConfig]] = {

def readAllDatasetSourceConfig(): Option[List[DatasetSourceConfig]] = {
val postgresConnect = new PostgresConnect(postgresConfig)
try {
val rs = postgresConnect.executeQuery("SELECT * FROM dataset_source_config")
Expand All @@ -70,16 +76,23 @@ object DatasetRegistryService {
}
}

def readDatasetSourceConfig(datasetId: String): Option[List[DatasetSourceConfig]] = {

def readDatasetSourceConfig(datasetId: String): Option[List[DatasetSourceConfig]] = {
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
var resultSet: ResultSet = null
try {
val rs = postgresConnect.executeQuery(s"SELECT * FROM dataset_source_config where dataset_id='$datasetId'")
Option(Iterator.continually((rs, rs.next)).takeWhile(f => f._2).map(f => f._1).map(result => {
val query = "SELECT * FROM dataset_source_config WHERE dataset_id = ?"
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, datasetId)
resultSet = postgresConnect.executeQuery(preparedStatement = preparedStatement)
Option(Iterator.continually((resultSet, resultSet.next)).takeWhile(f => f._2).map(f => f._1).map(result => {
val datasetSourceConfig = parseDatasetSourceConfig(result)
datasetSourceConfig
}).toList)
} finally {
if (resultSet != null) resultSet.close()
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}
Expand All @@ -99,14 +112,20 @@ object DatasetRegistryService {
}

def readDatasources(datasetId: String): Option[List[DataSource]] = {

val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
var resultSet: ResultSet = null
try {
val rs = postgresConnect.executeQuery(s"SELECT * FROM datasources where dataset_id='$datasetId'")
Option(Iterator.continually((rs, rs.next)).takeWhile(f => f._2).map(f => f._1).map(result => {
val query = "SELECT * FROM datasources WHERE dataset_id = ?"
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, datasetId)
resultSet = postgresConnect.executeQuery(preparedStatement = preparedStatement)
Option(Iterator.continually((resultSet, resultSet.next)).takeWhile(f => f._2).map(f => f._1).map(result => {
parseDatasource(result)
}).toList)
} finally {
if (resultSet != null) resultSet.close()
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}
Expand All @@ -123,33 +142,65 @@ object DatasetRegistryService {
}

def updateDatasourceRef(datasource: DataSource, datasourceRef: String): Int = {
val query = s"UPDATE datasources set datasource_ref = '$datasourceRef' where datasource='${datasource.datasource}' and dataset_id='${datasource.datasetId}'"
updateRegistry(query)
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
val query = "UPDATE datasources SET datasource_ref = ? WHERE datasource = ? AND dataset_id = ?"
try {
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, datasourceRef)
preparedStatement.setString(2, datasource.datasource)
preparedStatement.setString(3, datasource.datasetId)
postgresConnect.executeUpdate(preparedStatement)
} finally {
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}

def updateConnectorStats(id: String, lastFetchTimestamp: Timestamp, records: Long): Int = {
val query = s"UPDATE dataset_source_config SET connector_stats = coalesce(connector_stats, '{}')::jsonb || " +
s"jsonb_build_object('records', COALESCE(connector_stats->>'records', '0')::int + '$records'::int) || " +
s"jsonb_build_object('last_fetch_timestamp', '${lastFetchTimestamp}'::timestamp) || " +
s"jsonb_build_object('last_run_timestamp', '${new Timestamp(System.currentTimeMillis())}'::timestamp) WHERE id = '$id';"
updateRegistry(query)
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
val query = "UPDATE dataset_source_config SET connector_stats = COALESCE(connector_stats, '{}')::jsonb || jsonb_build_object('records', COALESCE(connector_stats->>'records', '0')::int + ? ::int) || jsonb_build_object('last_fetch_timestamp', ? ::timestamp) || jsonb_build_object('last_run_timestamp', ? ::timestamp) WHERE id = ?;"
try {
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, records.toString)
preparedStatement.setTimestamp(2, lastFetchTimestamp)
preparedStatement.setTimestamp(3, new Timestamp(System.currentTimeMillis()))
preparedStatement.setString(4, id)
postgresConnect.executeUpdate(preparedStatement)
} finally {
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}


def updateConnectorDisconnections(id: String, disconnections: Int): Int = {
val query = s"UPDATE dataset_source_config SET connector_stats = jsonb_set(coalesce(connector_stats, '{}')::jsonb, '{disconnections}','$disconnections') WHERE id = '$id'"
updateRegistry(query)
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
val query = "UPDATE dataset_source_config SET connector_stats = jsonb_set(coalesce(connector_stats, '{}')::jsonb, '{disconnections}', to_jsonb(?)) WHERE id = ?"
try {
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setInt(1, disconnections)
preparedStatement.setString(2, id)
postgresConnect.executeUpdate(preparedStatement)
} finally {
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}

def updateConnectorAvgBatchReadTime(id: String, avgReadTime: Long): Int = {
val query = s"UPDATE dataset_source_config SET connector_stats = jsonb_set(coalesce(connector_stats, '{}')::jsonb, '{avg_batch_read_time}','$avgReadTime') WHERE id = '$id'"
updateRegistry(query)
}

private def updateRegistry(query: String): Int = {
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
val query = "UPDATE dataset_source_config SET connector_stats = jsonb_set(coalesce(connector_stats, '{}')::jsonb, '{avg_batch_read_time}', to_jsonb(?)) WHERE id = ?"
try {
postgresConnect.executeUpdate(query)
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setLong(1, avgReadTime)
preparedStatement.setString(2, id)
postgresConnect.executeUpdate(preparedStatement)
} finally {
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import org.sunbird.obsrv.core.model.Models.SystemSetting
import org.sunbird.obsrv.core.util.{PostgresConnect, PostgresConnectionConfig}

import java.io.File
import java.sql.ResultSet
import java.sql.{PreparedStatement, ResultSet}

object SystemConfig {

Expand Down Expand Up @@ -102,10 +102,17 @@ object SystemConfigService {
@throws[Exception]
def getSystemSetting(key: String): Option[SystemSetting] = {
val postgresConnect = new PostgresConnect(postgresConfig)
var preparedStatement: PreparedStatement = null
var rs: ResultSet = null
val query = "SELECT * FROM system_settings WHERE key = ?"
preparedStatement = postgresConnect.prepareStatement(query)
preparedStatement.setString(1, key)
try {
val rs = postgresConnect.executeQuery(s"SELECT * FROM system_settings WHERE key = '$key'")
rs = postgresConnect.executeQuery(preparedStatement = preparedStatement)
if (rs.next) Option(parseSystemSetting(rs)) else None
} finally {
if (rs != null) rs.close()
if (preparedStatement != null) preparedStatement.close()
postgresConnect.closeConnection()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.sunbird.obsrv.core.util
import org.postgresql.ds.PGSimpleDataSource
import org.slf4j.LoggerFactory

import java.sql.{Connection, ResultSet, SQLException, Statement}
import java.sql.{Connection, PreparedStatement, ResultSet, SQLException, Statement}

final case class PostgresConnectionConfig(user: String, password: String, database: String, host: String, port: Int, maxConnections: Int)

Expand Down Expand Up @@ -71,6 +71,41 @@ class PostgresConnect(config: PostgresConnectionConfig) {
// $COVERAGE-ON$
}

def prepareStatement(query: String): PreparedStatement = {
try {
connection.prepareStatement(query)
} catch {
case ex: SQLException =>
ex.printStackTrace()
logger.error("PostgresConnect:prepareStatement() - Exception", ex)
reset()
connection.prepareStatement(query)
}
}

def executeUpdate(preparedStatement: PreparedStatement): Int = {
try {
preparedStatement.executeUpdate()
} catch {
case ex: SQLException =>
ex.printStackTrace()
logger.error("PostgresConnect:executeUpdate():PreparedStatement - Exception", ex)
reset()
preparedStatement.executeUpdate()
}
}

def executeQuery(preparedStatement: PreparedStatement): ResultSet = {
try {
preparedStatement.executeQuery()
} catch {
case ex: SQLException =>
logger.error("PostgresConnect:execute():PreparedStatement - Exception", ex)
reset()
preparedStatement.executeQuery()
}
}

def executeQuery(query:String):ResultSet = statement.executeQuery(query)
}

Expand Down

0 comments on commit 65074e3

Please sign in to comment.