feat: support user defined prepared statement placeholder (#59)
whhe authored Mar 20, 2024
1 parent f65041b commit 969383e
Showing 7 changed files with 160 additions and 13 deletions.
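The change threads an optional SerializableFunction<String, String> through TableInfo and the dialects: it maps each field name to the placeholder text used in the prepared statement, and a null function keeps the previous behavior of a plain "?". A minimal sketch of how a caller might supply one (the ST_GeomFromText wrapping mirrors the GIS test added in this commit; the table id and column list are illustrative):

// User-defined placeholder: keep "?" for the key column and wrap the other
// columns in ST_GeomFromText(?) so WKT strings are converted on the server.
SerializableFunction<String, String> placeholderFunc =
        field -> "id".equals(field) ? "?" : "ST_GeomFromText(?)";

TableInfo tableInfo =
        new TableInfo(
                tableId,                          // target table
                Collections.singletonList("id"),  // primary key
                Arrays.asList("id", "point_c"),   // field names
                Arrays.asList(
                        DataTypes.INT().getLogicalType(),
                        DataTypes.STRING().getLogicalType()),
                placeholderFunc);                 // null falls back to "?"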
@@ -20,6 +20,7 @@
import org.apache.flink.table.catalog.UniqueConstraint;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.util.function.SerializableFunction;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -39,6 +40,7 @@ public class TableInfo implements Table {
private final List<String> fieldNames;
private final Map<String, Integer> fieldIndexMap;
private final List<LogicalType> dataTypes;
private final SerializableFunction<String, String> placeholderFunc;

public TableInfo(TableId tableId, ResolvedSchema resolvedSchema) {
this(
@@ -47,14 +49,16 @@ public TableInfo(TableId tableId, ResolvedSchema resolvedSchema) {
resolvedSchema.getColumnNames(),
resolvedSchema.getColumnDataTypes().stream()
.map(DataType::getLogicalType)
.collect(Collectors.toList()));
.collect(Collectors.toList()),
null);
}

public TableInfo(
@Nonnull TableId tableId,
@Nullable List<String> primaryKey,
@Nonnull List<String> fieldNames,
@Nonnull List<LogicalType> dataTypes) {
@Nonnull List<LogicalType> dataTypes,
@Nullable SerializableFunction<String, String> placeholderFunc) {
this.tableId = tableId;
this.primaryKey = primaryKey;
this.fieldNames = fieldNames;
@@ -63,6 +67,7 @@ public TableInfo(
IntStream.range(0, fieldNames.size())
.boxed()
.collect(Collectors.toMap(fieldNames::get, i -> i));
this.placeholderFunc = placeholderFunc;
}

@Override
@@ -88,6 +93,10 @@ public List<LogicalType> getDataTypes() {
return dataTypes;
}

public SerializableFunction<String, String> getPlaceholderFunc() {
return placeholderFunc;
}

@Override
public boolean equals(Object o) {
if (this == o) {
@@ -16,10 +16,12 @@

package com.oceanbase.connector.flink.dialect;

import org.apache.flink.util.function.SerializableFunction;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.io.Serializable;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

@@ -44,36 +46,56 @@ default String getFullTableName(@Nonnull String schemaName, @Nonnull String tabl
return String.format("%s.%s", quoteIdentifier(schemaName), quoteIdentifier(tableName));
}

/**
* Gets the placeholder for prepared statement
*
* @param fieldName field name
* @param placeholderFunc user defined placeholder function
* @return the placeholder for prepared statement
*/
default String getPlaceholder(
@Nonnull String fieldName,
@Nullable SerializableFunction<String, String> placeholderFunc) {
return placeholderFunc != null ? placeholderFunc.apply(fieldName) : "?";
}

/**
* Gets the upsert statement
*
* @param schemaName schema name
* @param tableName table name
* @param fieldNames field names list
* @param uniqueKeyFields unique key field names list
* @param placeholderFunc function used to get placeholder for the fields
* @return the statement string
*/
String getUpsertStatement(
@Nonnull String schemaName,
@Nonnull String tableName,
@Nonnull List<String> fieldNames,
@Nonnull List<String> uniqueKeyFields);
@Nonnull List<String> uniqueKeyFields,
@Nullable SerializableFunction<String, String> placeholderFunc);

/**
* Gets the insert statement
*
* @param schemaName schema name
* @param tableName table name
* @param fieldNames field names list
* @param placeholderFunc function used to get placeholder for the fields
* @return the statement string
*/
default String getInsertIntoStatement(
@Nonnull String schemaName,
@Nonnull String tableName,
@Nonnull List<String> fieldNames) {
@Nonnull List<String> fieldNames,
@Nullable SerializableFunction<String, String> placeholderFunc) {
String columns =
fieldNames.stream().map(this::quoteIdentifier).collect(Collectors.joining(", "));
String placeholders = String.join(", ", Collections.nCopies(fieldNames.size(), "?"));
String placeholders =
fieldNames.stream()
.map(f -> getPlaceholder(f, placeholderFunc))
.collect(Collectors.joining(", "));
return "INSERT INTO "
+ getFullTableName(schemaName, tableName)
+ "("
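For illustration, with the placeholder function sketched above, the default getInsertIntoStatement would now produce roughly the following; the schema name and identifier quoting are illustrative and depend on the dialect's getFullTableName and quoteIdentifier:

String insertSql =
        dialect.getInsertIntoStatement(
                "test", "gis_types",
                Arrays.asList("id", "point_c"),
                placeholderFunc);
// roughly: INSERT INTO `test`.`gis_types`(`id`, `point_c`) VALUES (?, ST_GeomFromText(?))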
@@ -16,7 +16,10 @@

package com.oceanbase.connector.flink.dialect;

import org.apache.flink.util.function.SerializableFunction;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.util.List;
import java.util.stream.Collectors;
@@ -35,13 +38,14 @@ public String getUpsertStatement(
@Nonnull String schemaName,
@Nonnull String tableName,
@Nonnull List<String> fieldNames,
@Nonnull List<String> uniqueKeyFields) {
@Nonnull List<String> uniqueKeyFields,
@Nullable SerializableFunction<String, String> placeholderFunc) {
String updateClause =
fieldNames.stream()
.filter(f -> !uniqueKeyFields.contains(f))
.map(f -> quoteIdentifier(f) + "=VALUES(" + quoteIdentifier(f) + ")")
.collect(Collectors.joining(", "));
return getInsertIntoStatement(schemaName, tableName, fieldNames)
return getInsertIntoStatement(schemaName, tableName, fieldNames, placeholderFunc)
+ " ON DUPLICATE KEY UPDATE "
+ updateClause;
}
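For the MySQL-mode dialect above, the upsert built from the same two columns would then look roughly like this (quoting again illustrative):

String upsertSql =
        dialect.getUpsertStatement(
                "test", "gis_types",
                Arrays.asList("id", "point_c"),
                Collections.singletonList("id"),
                placeholderFunc);
// roughly: INSERT INTO `test`.`gis_types`(`id`, `point_c`) VALUES (?, ST_GeomFromText(?))
//          ON DUPLICATE KEY UPDATE `point_c`=VALUES(`point_c`)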
@@ -16,7 +16,10 @@

package com.oceanbase.connector.flink.dialect;

import org.apache.flink.util.function.SerializableFunction;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import java.util.List;
import java.util.stream.Collectors;
@@ -35,10 +38,11 @@ public String getUpsertStatement(
@Nonnull String schemaName,
@Nonnull String tableName,
@Nonnull List<String> fieldNames,
@Nonnull List<String> uniqueKeyFields) {
@Nonnull List<String> uniqueKeyFields,
@Nullable SerializableFunction<String, String> placeholderFunc) {
String sourceFields =
fieldNames.stream()
.map(f -> "? AS " + quoteIdentifier(f))
.map(f -> getPlaceholder(f, placeholderFunc) + " AS " + quoteIdentifier(f))
.collect(Collectors.joining(", "));

String onClause =
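In this second dialect (a MERGE-style upsert, judging from the aliased source row and the ON clause), the function feeds the source-row fragment instead, so for the same example columns it would read roughly ? AS "id", ST_GeomFromText(?) AS "point_c"; the exact identifier quoting and the remainder of the statement are outside the visible hunk.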
@@ -151,7 +151,8 @@ public synchronized void flush(List<DataChangeRecord> records) throws Exception
dialect.getInsertIntoStatement(
tableId.getSchemaName(),
tableId.getTableName(),
tableInfo.getFieldNames()),
tableInfo.getFieldNames(),
tableInfo.getPlaceholderFunc()),
tableInfo.getFieldNames(),
upsertBatch);
} else {
@@ -160,7 +161,8 @@
tableId.getSchemaName(),
tableId.getTableName(),
tableInfo.getFieldNames(),
tableInfo.getKey()),
tableInfo.getKey(),
tableInfo.getPlaceholderFunc()),
tableInfo.getFieldNames(),
upsertBatch);
}
@@ -60,6 +60,7 @@
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class OceanBaseConnectorITCase extends OceanBaseTestBase {
@@ -291,6 +292,91 @@ private RowData rowData(int id, String name, String description, double weight)
DecimalData.fromBigDecimal(new BigDecimal(weight), 20, 10));
}

@Test
public void testGis() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setParallelism(1);

String tableName = "gis_types";
Map<String, String> options = getOptions();
options.put("table-name", tableName);

OceanBaseConnectorOptions connectorOptions = new OceanBaseConnectorOptions(options);
OceanBaseConnectionProvider connectionProvider =
new OceanBaseConnectionProvider(connectorOptions);
OceanBaseSink<RowData> sink =
new OceanBaseSink<>(
connectorOptions,
null,
new OceanBaseRowDataSerializationSchema(
new TableInfo(
new TableId(
connectionProvider.getDialect()::getFullTableName,
connectorOptions.getSchemaName(),
connectorOptions.getTableName()),
Collections.singletonList("id"),
Arrays.asList(
"id",
"point_c",
"geometry_c",
"linestring_c",
"polygon_c",
"multipoint_c",
"multiline_c",
"multipolygon_c",
"geometrycollection_c"),
Arrays.asList(
DataTypes.INT().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType(),
DataTypes.STRING().getLogicalType()),
x -> "id".equals(x) ? "?" : "ST_GeomFromText(?)")),
DataChangeRecord.KeyExtractor.simple(),
new OceanBaseRecordFlusher(connectorOptions));

List<String> values =
Arrays.asList(
"POINT(1 1)",
"POLYGON((1 1,2 1,2 2,1 2,1 1))",
"LINESTRING(3 0,3 3,3 5)",
"POLYGON((1 1,2 1,2 2,1 2,1 1))",
"MULTIPOINT((1 1),(2 2))",
"MULTILINESTRING((1 1,2 2,3 3),(4 4,5 5))",
"MULTIPOLYGON(((0 0,10 0,10 10,0 10,0 0)),((5 5,7 5,7 7,5 7,5 5)))",
"GEOMETRYCOLLECTION(POINT(10 10),POINT(30 30),LINESTRING(15 15,20 20))");

GenericRowData rowData = new GenericRowData(RowKind.INSERT, values.size() + 1);
rowData.setField(0, 1);
for (int i = 0; i < values.size(); i++) {
rowData.setField(i + 1, StringData.fromString(values.get(i)));
}

env.fromElements((RowData) rowData).sinkTo(sink);
env.execute();

waitForTableCount(tableName, 1);
List<String> actual =
queryTable(
tableName,
Arrays.asList(
"id",
"ST_AsWKT(point_c)",
"ST_AsWKT(geometry_c)",
"ST_AsWKT(linestring_c)",
"ST_AsWKT(polygon_c)",
"ST_AsWKT(multipoint_c)",
"ST_AsWKT(multiline_c)",
"ST_AsWKT(multipolygon_c)",
"ST_AsWKT(geometrycollection_c)"));

assertEquals(actual.get(0), "1," + String.join(",", values));
}

@Test
public void testDirectLoadSink() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
@@ -399,10 +485,16 @@ private void waitForTableCount(String tableName, int expectedCount)
}

public List<String> queryTable(String tableName) throws SQLException {
return queryTable(tableName, Collections.singletonList("*"));
}

public List<String> queryTable(String tableName, List<String> fields) throws SQLException {
List<String> result = new ArrayList<>();
try (Connection connection = getConnection();
Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery("SELECT * FROM " + tableName);
ResultSet rs =
statement.executeQuery(
"SELECT " + String.join(", ", fields) + " FROM " + tableName);
ResultSetMetaData metaData = rs.getMetaData();

while (rs.next()) {
14 changes: 14 additions & 0 deletions flink-connector-oceanbase/src/test/resources/sql/init.sql
@@ -21,3 +21,17 @@ CREATE TABLE products
description VARCHAR(512),
weight DECIMAL(20, 10)
);


CREATE TABLE gis_types
(
id INTEGER NOT NULL AUTO_INCREMENT PRIMARY KEY,
point_c POINT,
geometry_c GEOMETRY,
linestring_c LINESTRING,
polygon_c POLYGON,
multipoint_c MULTIPOINT,
multiline_c MULTILINESTRING,
multipolygon_c MULTIPOLYGON,
geometrycollection_c GEOMETRYCOLLECTION
)
