Skip to content

Commit

Permalink
[FLINK-37084][python] Fix TimerRegistration concurrency issue in PyFl…
Browse files Browse the repository at this point in the history
…ink (#26004)
  • Loading branch information
suez1224 authored Jan 17, 2025
1 parent d8cd559 commit 82baf88
Show file tree
Hide file tree
Showing 33 changed files with 133 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public interface PythonFunctionRunner extends AutoCloseable {
/** Send the triggered timer to the Python function. */
void processTimer(byte[] timerData) throws Exception;

void drainUnregisteredTimers();

/**
* Retrieves the Python function result.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,17 +268,25 @@ public Configuration getConfiguration() {

protected abstract PythonEnvironmentManager createPythonEnvironmentManager();

protected void drainUnregisteredTimers() {}

/**
* Advances the watermark of all managed timer services, potentially firing event time timers.
* It also ensures that the fired timers are processed in the Python user-defined functions.
*/
private void advanceWatermark(Watermark watermark) throws Exception {
if (getTimeServiceManager().isPresent()) {
InternalTimeServiceManager<?> timeServiceManager = getTimeServiceManager().get();
// make sure the registered timer are processed before advancing the watermark to
// ensure the timers could be triggered
drainUnregisteredTimers();
timeServiceManager.advanceWatermark(watermark);

while (!isBundleFinished()) {
invokeFinishBundle();
// make sure the registered timer are processed before advancing the watermark to
// ensure the timers could be triggered
drainUnregisteredTimers();
timeServiceManager.advanceWatermark(watermark);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ protected ProcessPythonEnvironmentManager createPythonEnvironmentManager() {
}
}

@Override
protected void drainUnregisteredTimers() {
pythonFunctionRunner.drainUnregisteredTimers();
}

protected void emitResults() throws Exception {
Tuple3<String, byte[], Integer> resultTuple;
while ((resultTuple = pythonFunctionRunner.pollResult()) != null && resultTuple.f2 != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public void open() throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return new BeamDataStreamPythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
STATELESS_FUNCTION_URN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ public void open() throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return new BeamDataStreamPythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
STATEFUL_FUNCTION_URN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ public void onProcessingTime(InternalTimer<Row, Object> timer) throws Exception
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return new BeamDataStreamPythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
STATEFUL_FUNCTION_URN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public void open() throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return new BeamDataStreamPythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
STATELESS_FUNCTION_URN,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.streaming.api.operators.python.process.timer;

public class TimerRegistrationAction {

private final TimerRegistration timerRegistration;

private final byte[] serializedTimerData;

private boolean isRegistered;

public TimerRegistrationAction(
TimerRegistration timerRegistration, byte[] serializedTimerData) {
this.timerRegistration = timerRegistration;
this.serializedTimerData = serializedTimerData;
this.isRegistered = false;
}

public void run() {
if (!isRegistered) {
timerRegistration.setTimer(serializedTimerData);
isRegistered = true;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.python.env.process.ProcessPythonEnvironmentManager;
import org.apache.flink.python.metric.process.FlinkMetricContainer;
import org.apache.flink.python.util.ProtoUtils;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.OperatorStateBackend;
Expand Down Expand Up @@ -69,6 +70,7 @@ public class BeamDataStreamPythonFunctionRunner extends BeamPythonFunctionRunner
private final List<FlinkFnApi.UserDefinedDataStreamFunction> userDefinedDataStreamFunctions;

public BeamDataStreamPythonFunctionRunner(
Environment environment,
String taskName,
ProcessPythonEnvironmentManager environmentManager,
String headOperatorFunctionUrn,
Expand All @@ -86,6 +88,7 @@ public BeamDataStreamPythonFunctionRunner(
@Nullable FlinkFnApi.CoderInfoDescriptor timerCoderDescriptor,
Map<String, FlinkFnApi.CoderInfoDescriptor> sideOutputCoderDescriptors) {
super(
environment,
taskName,
environmentManager,
flinkMetricContainer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
import org.apache.flink.python.env.process.ProcessPythonEnvironment;
import org.apache.flink.python.env.process.ProcessPythonEnvironmentManager;
import org.apache.flink.python.metric.process.FlinkMetricContainer;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.memory.OpaqueMemoryResource;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.streaming.api.operators.python.process.timer.TimerRegistration;
import org.apache.flink.streaming.api.operators.python.process.timer.TimerRegistrationAction;
import org.apache.flink.streaming.api.runners.python.beam.state.BeamStateRequestHandler;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.ShutdownHookUtil;
Expand Down Expand Up @@ -85,6 +87,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -190,7 +193,12 @@ public abstract class BeamPythonFunctionRunner implements PythonFunctionRunner {

private transient Thread shutdownHook;

private transient Environment environment;

private transient List<TimerRegistrationAction> unregisteredTimers;

public BeamPythonFunctionRunner(
Environment environment,
String taskName,
ProcessPythonEnvironmentManager environmentManager,
@Nullable FlinkMetricContainer flinkMetricContainer,
Expand All @@ -204,6 +212,7 @@ public BeamPythonFunctionRunner(
FlinkFnApi.CoderInfoDescriptor inputCoderDescriptor,
FlinkFnApi.CoderInfoDescriptor outputCoderDescriptor,
Map<String, FlinkFnApi.CoderInfoDescriptor> sideOutputCoderDescriptors) {
this.environment = environment;
this.taskName = Preconditions.checkNotNull(taskName);
this.environmentManager = Preconditions.checkNotNull(environmentManager);
this.flinkMetricContainer = flinkMetricContainer;
Expand Down Expand Up @@ -301,6 +310,8 @@ public void open(ReadableConfig config) throws Exception {
shutdownHook =
ShutdownHookUtil.addShutdownHook(
this, BeamPythonFunctionRunner.class.getSimpleName(), LOG);

unregisteredTimers = new LinkedList<>();
}

@Override
Expand Down Expand Up @@ -339,6 +350,14 @@ public void process(byte[] data) throws Exception {
mainInputReceiver.accept(WindowedValue.valueInGlobalWindow(data));
}

@Override
public void drainUnregisteredTimers() {
for (TimerRegistrationAction timerRegistrationAction : unregisteredTimers) {
timerRegistrationAction.run();
}
unregisteredTimers.clear();
}

@Override
public void processTimer(byte[] timerData) throws Exception {
if (timerInputReceiver == null) {
Expand Down Expand Up @@ -681,7 +700,15 @@ public void onCompleted(BeamFnApi.ProcessBundleResponse response) {

private TimerReceiverFactory createTimerReceiverFactory() {
BiConsumer<Timer<?>, TimerInternals.TimerData> timerDataConsumer =
(timer, timerData) -> timerRegistration.setTimer((byte[]) timer.getUserKey());
(timer, timerData) -> {
TimerRegistrationAction timerRegistrationAction =
new TimerRegistrationAction(
timerRegistration, (byte[]) timer.getUserKey());
unregisteredTimers.add(timerRegistrationAction);
environment
.getMainMailboxExecutor()
.execute(timerRegistrationAction::run, "PythonTimerRegistration");
};
return new TimerReceiverFactory(stageBundleFactory, timerDataConsumer, null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ public void processElement(StreamRecord<IN> element) throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws IOException {
return BeamTablePythonFunctionRunner.stateless(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
getFunctionUrn(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ public void processElement(StreamRecord<RowData> element) throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return BeamTablePythonFunctionRunner.stateful(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
createPythonEnvironmentManager(),
getFunctionUrn(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.python.env.process.ProcessPythonEnvironmentManager;
import org.apache.flink.python.metric.process.FlinkMetricContainer;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.streaming.api.runners.python.beam.BeamPythonFunctionRunner;
Expand Down Expand Up @@ -52,6 +53,7 @@ public class BeamTablePythonFunctionRunner extends BeamPythonFunctionRunner {
private final GeneratedMessageV3 userDefinedFunctionProto;

public BeamTablePythonFunctionRunner(
Environment environment,
String taskName,
ProcessPythonEnvironmentManager environmentManager,
String functionUrn,
Expand All @@ -65,6 +67,7 @@ public BeamTablePythonFunctionRunner(
FlinkFnApi.CoderInfoDescriptor inputCoderDescriptor,
FlinkFnApi.CoderInfoDescriptor outputCoderDescriptor) {
super(
environment,
taskName,
environmentManager,
flinkMetricContainer,
Expand Down Expand Up @@ -117,6 +120,7 @@ public void processTimer(byte[] timerData) throws Exception {
}

public static BeamTablePythonFunctionRunner stateless(
Environment environment,
String taskName,
ProcessPythonEnvironmentManager environmentManager,
String functionUrn,
Expand All @@ -127,6 +131,7 @@ public static BeamTablePythonFunctionRunner stateless(
FlinkFnApi.CoderInfoDescriptor inputCoderDescriptor,
FlinkFnApi.CoderInfoDescriptor outputCoderDescriptor) {
return new BeamTablePythonFunctionRunner(
environment,
taskName,
environmentManager,
functionUrn,
Expand All @@ -142,6 +147,7 @@ public static BeamTablePythonFunctionRunner stateless(
}

public static BeamTablePythonFunctionRunner stateful(
Environment environment,
String taskName,
ProcessPythonEnvironmentManager environmentManager,
String functionUrn,
Expand All @@ -155,6 +161,7 @@ public static BeamTablePythonFunctionRunner stateful(
FlinkFnApi.CoderInfoDescriptor inputCoderDescriptor,
FlinkFnApi.CoderInfoDescriptor outputCoderDescriptor) {
return new BeamTablePythonFunctionRunner(
environment,
taskName,
environmentManager,
functionUrn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ public void open() throws Exception {
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws Exception {
return new PassThroughStreamGroupWindowAggregatePythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
userDefinedFunctionInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ private static class PassThroughPythonStreamGroupAggregateOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughStreamAggregatePythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
userDefinedFunctionInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ private static class PassThroughPythonStreamGroupTableAggregateOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughStreamTableAggregatePythonFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
userDefinedFunctionInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ private static class PassThroughBatchArrowPythonGroupAggregateFunctionOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ private static class PassThroughBatchArrowPythonGroupWindowAggregateFunctionOper
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ private static class PassThroughBatchArrowPythonOverWindowAggregateFunctionOpera
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ private static class PassThroughStreamArrowPythonGroupWindowAggregateFunctionOpe
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ private static class PassThroughStreamArrowPythonProcTimeBoundedRangeOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ private static class PassThroughStreamArrowPythonProcTimeBoundedRowsOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ private static class PassThroughStreamArrowPythonRowTimeBoundedRangeOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ private static class PassThroughStreamArrowPythonRowTimeBoundedRowsOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() {
return new PassThroughPythonAggregateFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ private static class PassThroughPythonScalarFunctionOperator
@Override
public PythonFunctionRunner createPythonFunctionRunner() throws IOException {
return new PassThroughPythonScalarFunctionRunner(
getContainingTask().getEnvironment(),
getRuntimeContext().getTaskInfo().getTaskName(),
PythonTestUtils.createTestProcessEnvironmentManager(),
udfInputType,
Expand Down
Loading

0 comments on commit 82baf88

Please sign in to comment.