Skip to content

Commit

Permalink
concord-server: refactor WebSocketChannelManager, allow message sourc…
Browse files Browse the repository at this point in the history
…es in plugins
  • Loading branch information
ibodrov committed Dec 23, 2024
1 parent 1238a6e commit 560c16a
Show file tree
Hide file tree
Showing 15 changed files with 250 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import com.walmartlabs.concord.server.boot.BackgroundTasks;
import com.walmartlabs.concord.server.sdk.rest.Resource;
import com.walmartlabs.concord.server.task.TaskScheduler;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import com.walmartlabs.concord.server.websocket.MessageChannelManager;
import org.jooq.Configuration;

import javax.inject.Inject;
Expand All @@ -42,18 +42,18 @@ public class ServerResource implements Resource {

private final TaskScheduler taskScheduler;
private final BackgroundTasks backgroundTasks;
private final WebSocketChannelManager webSocketChannelManager;
private final MessageChannelManager messageChannelManager;
private final PingDao pingDao;

@Inject
public ServerResource(TaskScheduler taskScheduler,
BackgroundTasks backgroundTasks,
WebSocketChannelManager webSocketChannelManager,
MessageChannelManager messageChannelManager,
PingDao pingDao) {

this.taskScheduler = taskScheduler;
this.backgroundTasks = backgroundTasks;
this.webSocketChannelManager = webSocketChannelManager;
this.messageChannelManager = messageChannelManager;
this.pingDao = pingDao;
}

Expand All @@ -78,7 +78,7 @@ public VersionResponse version() {
public void maintenanceMode() {
backgroundTasks.stop();

webSocketChannelManager.shutdown();
messageChannelManager.shutdown();
taskScheduler.stop();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
import com.walmartlabs.concord.server.queueclient.message.MessageType;
import com.walmartlabs.concord.server.queueclient.message.ProcessRequest;
import com.walmartlabs.concord.server.sdk.ProcessKey;
import com.walmartlabs.concord.server.websocket.MessageChannel;
import com.walmartlabs.concord.server.websocket.MessageChannelManager;
import com.walmartlabs.concord.server.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import org.jooq.DSLContext;

import javax.inject.Inject;
Expand All @@ -38,23 +39,24 @@
public class AgentManager {

private final AgentCommandsDao commandQueue;
private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;

@Inject
public AgentManager(AgentCommandsDao commandQueue,
WebSocketChannelManager channelManager) {
MessageChannelManager channelManager) {

this.commandQueue = commandQueue;
this.channelManager = channelManager;
}

public Collection<AgentWorkerEntry> getAvailableAgents() {
Map<WebSocketChannel, ProcessRequest> reqs = channelManager.getRequests(MessageType.PROCESS_REQUEST);
Map<MessageChannel, ProcessRequest> reqs = channelManager.getRequests(MessageType.PROCESS_REQUEST);
return reqs.entrySet().stream()
.filter(r -> r.getKey() instanceof WebSocketChannel) // TODO a better way
.map(r -> AgentWorkerEntry.builder()
.channelId(r.getKey().getChannelId())
.agentId(r.getKey().getAgentId())
.userAgent(r.getKey().getUserAgent())
.userAgent(((WebSocketChannel) r.getKey()).getUserAgent())
.capabilities(r.getValue().getCapabilities())
.build())
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@

import javax.annotation.Nullable;
import java.util.Map;
import java.util.UUID;

@Value.Immutable
@JsonInclude(JsonInclude.Include.NON_EMPTY)
@JsonSerialize(as = ImmutableAgentWorkerEntry.class)
@JsonDeserialize(as = ImmutableAgentWorkerEntry.class)
public interface AgentWorkerEntry {

UUID channelId();
String channelId();

@Nullable // backward-compatibility with old agent versions
String agentId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
import com.walmartlabs.concord.server.queueclient.message.CommandResponse;
import com.walmartlabs.concord.server.queueclient.message.MessageType;
import com.walmartlabs.concord.server.sdk.metrics.WithTimer;
import com.walmartlabs.concord.server.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import com.walmartlabs.concord.server.websocket.MessageChannel;
import com.walmartlabs.concord.server.websocket.MessageChannelManager;
import org.jooq.Configuration;
import org.jooq.DSLContext;
import org.slf4j.Logger;
Expand All @@ -61,12 +61,12 @@ public class Dispatcher extends PeriodicTask {
private static final int BATCH_SIZE = 10;

private final DispatcherDao dao;
private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;

@Inject
public Dispatcher(AgentConfiguration cfg,
DispatcherDao dao,
WebSocketChannelManager channelManager) {
MessageChannelManager channelManager) {

super(cfg.getCommandPollDelay().toMillis(), ERROR_DELAY);
this.dao = dao;
Expand All @@ -75,7 +75,7 @@ public Dispatcher(AgentConfiguration cfg,

@Override
protected boolean performTask() {
Map<WebSocketChannel, CommandRequest> requests = this.channelManager.getRequests(MessageType.COMMAND_REQUEST);
Map<MessageChannel, CommandRequest> requests = this.channelManager.getRequests(MessageType.COMMAND_REQUEST);
if (requests.isEmpty()) {
return false;
}
Expand Down Expand Up @@ -148,7 +148,6 @@ private AgentCommand findCandidate(CommandRequest req, List<AgentCommand> candid
}

private void sendResponse(Match match, AgentCommand response) {
WebSocketChannel channel = match.request.channel;
long correlationId = match.request.request.getCorrelationId();

CommandType type = CommandType.valueOf((String) response.getData().remove(Commands.TYPE_KEY));
Expand All @@ -157,7 +156,8 @@ private void sendResponse(Match match, AgentCommand response) {
payload.put("type", type.toString());
payload.putAll(response.getData());

boolean success = channelManager.sendResponse(channel.getChannelId(), CommandResponse.cancel(correlationId, payload));
MessageChannel channel = match.request.channel;
boolean success = channelManager.sendMessage(channel.getChannelId(), CommandResponse.cancel(correlationId, payload));
if (success) {
log.info("sendResponse ['{}'] -> done", correlationId);
} else {
Expand Down Expand Up @@ -223,25 +223,10 @@ private AgentCommand convert(AgentCommandsRecord r) {
}
}

private static final class Match {
private record Match(Request request, AgentCommand command) {

private final Request request;
private final AgentCommand command;

private Match(Request request, AgentCommand command) {
this.request = request;
this.command = command;
}
}

private static final class Request {

private final WebSocketChannel channel;
private final CommandRequest request;

private Request(WebSocketChannel channel, CommandRequest request) {
this.channel = channel;
this.request = request;
}
private record Request(MessageChannel channel, CommandRequest request) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.jooq.impl.DSL;
import org.jooq.util.postgres.PostgresDSL;

import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Named;
import java.time.Duration;
Expand Down Expand Up @@ -140,7 +141,7 @@ public void insert(DSLContext tx, ProcessKey processKey, ProcessStatus status, P
}
}

public void updateAgentId(DSLContext tx, ProcessKey processKey, String agentId, ProcessStatus status) {
public void updateAgentId(DSLContext tx, ProcessKey processKey, @Nullable String agentId, ProcessStatus status) {
UUID instanceId = processKey.getInstanceId();

int i = tx.update(PROCESS_QUEUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.walmartlabs.concord.server.sdk.ProcessStatus;
import org.jooq.DSLContext;

import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Named;
import java.time.Duration;
Expand Down Expand Up @@ -193,7 +194,7 @@ public void updateAgentId(ProcessKey processKey, String agentId, ProcessStatus s
/**
* Updates the process' agent ID and status.
*/
public void updateAgentId(DSLContext tx, ProcessKey processKey, String agentId, ProcessStatus status) {
public void updateAgentId(DSLContext tx, ProcessKey processKey, @Nullable String agentId, ProcessStatus status) {
queueDao.updateAgentId(tx, processKey, agentId, status);
notifyStatusChange(tx, processKey, status);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
import com.walmartlabs.concord.server.sdk.ProcessKey;
import com.walmartlabs.concord.server.sdk.ProcessStatus;
import com.walmartlabs.concord.server.sdk.metrics.WithTimer;
import com.walmartlabs.concord.server.websocket.MessageChannel;
import com.walmartlabs.concord.server.websocket.MessageChannelManager;
import com.walmartlabs.concord.server.websocket.WebSocketChannel;
import com.walmartlabs.concord.server.websocket.WebSocketChannelManager;
import org.jooq.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -77,7 +78,7 @@ public class Dispatcher extends PeriodicTask {

private final Locks locks;
private final DispatcherDao dao;
private final WebSocketChannelManager channelManager;
private final MessageChannelManager channelManager;
private final ProcessLogManager logManager;
private final ProcessQueueManager queueManager;
private final Set<Filter> filters;
Expand All @@ -93,7 +94,7 @@ public class Dispatcher extends PeriodicTask {
@Inject
public Dispatcher(Locks locks,
DispatcherDao dao,
WebSocketChannelManager channelManager,
MessageChannelManager channelManager,
ProcessLogManager logManager,
ProcessQueueManager queueManager,
Set<Filter> filters,
Expand Down Expand Up @@ -125,7 +126,7 @@ protected boolean performTask() {
// TODO the WebSocketChannelManager business can be replaced with an async jax-rs endpoint and an "inbox" queue

// grab the requests w/o responses
Map<WebSocketChannel, ProcessRequest> requests = this.channelManager.getRequests(MessageType.PROCESS_REQUEST);
Map<MessageChannel, ProcessRequest> requests = this.channelManager.getRequests(MessageType.PROCESS_REQUEST);
if (requests.isEmpty()) {
return false;
}
Expand Down Expand Up @@ -206,7 +207,8 @@ private List<Match> match(DSLContext tx, List<Request> requests) {
ProcessQueueEntry candidate = m.response;

// mark the process as STARTING
queueManager.updateAgentId(tx, candidate.key(), m.request.channel.getAgentId(), ProcessStatus.STARTING);
String agentId = m.request.channel.getAgentId();
queueManager.updateAgentId(tx, candidate.key(), agentId, ProcessStatus.STARTING);
}

return matches;
Expand Down Expand Up @@ -250,7 +252,6 @@ private boolean pass(DSLContext tx, ProcessQueueEntry e, List<ProcessQueueEntry>
}

private void sendResponse(Match match) {
WebSocketChannel channel = match.request.channel;
long correlationId = match.request.request.getCorrelationId();
ProcessQueueEntry item = match.response;

Expand All @@ -275,17 +276,23 @@ private void sendResponse(Match match) {
secret != null ? secret.secretName : null,
imports);

if (!channelManager.sendResponse(channel.getChannelId(), resp)) {
MessageChannel channel = match.request.channel;
if (!channelManager.sendMessage(channel.getChannelId(), resp)) {
log.warn("sendResponse ['{}'] -> failed", correlationId);
}

logManager.info(item.key(), "Acquired by: " + channel.getUserAgent());
// TODO a way to avoid instanceof here
String userAgent = channel instanceof WebSocketChannel ? ((WebSocketChannel) channel).getUserAgent() : null;
if (userAgent != null) {
logManager.info(item.key(), "Acquired by: " + userAgent);
}
} catch (Exception e) {
log.error("sendResponse ['{}'] -> failed (instanceId: {})", correlationId, item.key().getInstanceId());
}
}

@Named
@SuppressWarnings("resource")
public static class DispatcherDao extends AbstractDao {

private final ConcordObjectMapper objectMapper;
Expand Down Expand Up @@ -316,20 +323,20 @@ public List<ProcessQueueEntry> next(DSLContext tx, int offset, int limit) {

SelectJoinStep<Record14<UUID, OffsetDateTime, UUID, UUID, UUID, UUID, String, String, String, UUID, JSONB, JSONB, JSONB, String>> s =
tx.select(
q.INSTANCE_ID,
q.CREATED_AT,
q.PROJECT_ID,
orgIdField,
q.INITIATOR_ID,
q.PARENT_INSTANCE_ID,
q.REPO_PATH,
q.REPO_URL,
q.COMMIT_ID,
q.REPO_ID,
q.IMPORTS,
q.REQUIREMENTS,
q.EXCLUSIVE,
q.COMMIT_BRANCH)
q.INSTANCE_ID,
q.CREATED_AT,
q.PROJECT_ID,
orgIdField,
q.INITIATOR_ID,
q.PARENT_INSTANCE_ID,
q.REPO_PATH,
q.REPO_URL,
q.COMMIT_ID,
q.REPO_ID,
q.IMPORTS,
q.REQUIREMENTS,
q.EXCLUSIVE,
q.COMMIT_BRANCH)
.from(q);

s.where(q.CURRENT_STATUS.eq(ProcessStatus.ENQUEUED.toString())
Expand Down Expand Up @@ -369,36 +376,12 @@ public SecretReference getSecretReference(UUID repoId) {
}
}

private static final class Request {

private final WebSocketChannel channel;
private final ProcessRequest request;

private Request(WebSocketChannel channel, ProcessRequest request) {
this.channel = channel;
this.request = request;
}
private record Request(MessageChannel channel, ProcessRequest request) {
}

private static final class Match {

private final Request request;
private final ProcessQueueEntry response;

private Match(Request request, ProcessQueueEntry response) {
this.request = request;
this.response = response;
}
private record Match(Request request, ProcessQueueEntry response) {
}

private static final class SecretReference {

private final String orgName;
private final String secretName;

private SecretReference(String orgName, String secretName) {
this.orgName = orgName;
this.secretName = secretName;
}
public record SecretReference(String orgName, String secretName) {
}
}
Loading

0 comments on commit 560c16a

Please sign in to comment.