Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization: flush and lazy synchronize on downstream for propagation #25

Open
wants to merge 1 commit into
base: lx-base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,13 @@ public void onSyncReceive(Message syncMessage, int numUpstreams){
blockedAddresses.add(addressMatch);
}
blockedMessages.addAll(pendingMessages);
System.out.println("onSyncReceive mailbox " + self + " blockedAddresses size " + blockedAddresses.size() + " status " + status + " tid: " + Thread.currentThread().getName());
System.out.println("onSyncReceive mailbox " + self
+ " blockedAddresses " + Arrays.toString(blockedAddresses.toArray())
+ " address match " + addressMatch
+ " status " + status + " tid: " + Thread.currentThread().getName());
if(blockedAddresses.size() == numUpstreams && status == Status.RUNNABLE){
System.out.println("onSyncReceive Mailbox " + self() + " ready to block on SYNC_ONE " + " blocked size " + blockedAddresses.size() + " status " + status);
System.out.println("onSyncReceive Mailbox " + self() + " ready to block on SYNC_ONE "
+ " blocked size " + blockedAddresses.size() + " status " + status);
this.readyToBlock = true;
}
} catch (Exception e) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public class MailboxState {

public Address pendingStateRequest;

public MailboxState(FunctionActivation.Status status, boolean readyToBlock, Address pendingStateRequest) {
public MailboxState(FunctionActivation.Status status, boolean readyToBlock, Address pendingStateRequest ) {
this.status = status;
this.readyToBlock = readyToBlock;
this.pendingStateRequest = pendingStateRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,15 @@ public Message forward(Address to, Message message, ClassLoader loader, boolean
Object what = message.payload(messageFactory, loader);
Objects.requireNonNull(what);
Address lessor = message.target();
Address currentAddress = (in == null?message.target() : in.target());// calling from enqueue or not
Address currentAddress = ((in == null || (message.getHostActivation().self().toInternalAddress() != in.target().toInternalAddress()))? message.getHostActivation().self() : in.target());// calling from enqueue or not
if(currentAddress == null){
throw new FlinkRuntimeException("Forward a message with no host activation on message " + message + " to " + to + " in message: " + (in==null?"null":in) + " tid: " + Thread.currentThread().getName());
}
if(force){
envelope = messageFactory.from(message.source(), to, what, message.getPriority().priority, message.getPriority().laxity,
Message.MessageType.FORWARDED, message.getMessageId());
System.out.println("Register route through forward: " + currentAddress + " -> " + to + " tid: " + Thread.currentThread().getName());
System.out.println("RouteTracker forward activateRoute from " + currentAddress + " to " + to + " on message " + message + " in " + (in==null?"null":in) + " tid: " + Thread.currentThread().getName());
ownerFunctionGroup.get().getRouteTracker().activateRoute(currentAddress, to);
}
else{
Expand Down Expand Up @@ -283,8 +287,11 @@ public void send(Address to, Object what) {
callbackPendings.get(in.target().toInternalAddress()).add(pendingEnvelope);
}
else{
System.out.println("Register route through send: " + in.target() + " -> " + envelope.target() + " tid: " + Thread.currentThread().getName());
System.out.println("RouteTracker send activateRoute from " + in.target() + " to " + envelope.target() + " on message " + envelope + " tid: " + Thread.currentThread().getName());
ownerFunctionGroup.get().getRouteTracker().activateRoute(in.target(), envelope.target());
if(envelope.getMessageType() == Message.MessageType.NON_FORWARDING){
ownerFunctionGroup.get().onSendingCrictical(envelope, in.getHostActivation());
}
if (thisPartition.contains(envelope.target())) {
localSinkPendingQueue.add(envelope);
// drainLocalSinkOutput();
Expand Down Expand Up @@ -314,12 +321,15 @@ public void sendComplete(Address initiator, Message previous, Message dispatch){
}
if(dispatch != null){
if(dispatch.isForwarded()){
System.out.println("Register route through sendComplete (forward): " + " ia " + ia.toAddress() + " from " + initiator + " -> " + dispatch.target() + " tid: " + Thread.currentThread().getName());
System.out.println("RouteTracker sendComplete (forward) activateRoute from " + initiator + " to " + dispatch.target() + " ia " + ia.toAddress() + " on message " + dispatch + " tid: " + Thread.currentThread().getName());
ownerFunctionGroup.get().getRouteTracker().activateRoute(initiator, dispatch.target());
}
else{
System.out.println("RouteTracker sendComplete activateRoute from " + initiator + " to " + dispatch.target() + " ia " + ia.toAddress() + " on message " + dispatch + " tid: " + Thread.currentThread().getName());
ownerFunctionGroup.get().getRouteTracker().activateRoute(initiator, dispatch.target());
System.out.println("Register route through sendComplete: " + initiator + " -> " + dispatch.target() + " tid: " + Thread.currentThread().getName());
}
if(dispatch.getMessageType() == Message.MessageType.NON_FORWARDING){
ownerFunctionGroup.get().onSendingCrictical(dispatch, ownerFunctionGroup.get().getActivation(initiator));
}
if (thisPartition.contains(dispatch.target())) {
localSinkPendingQueue.add(dispatch);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package org.apache.flink.statefun.flink.core.functions;

import javafx.util.Pair;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.InternalAddress;
import org.apache.flink.statefun.sdk.utils.DataflowUtils;
import org.apache.flink.util.FlinkRuntimeException;
import scala.Int;

import java.util.*;
import java.util.stream.Collectors;
Expand All @@ -12,6 +15,9 @@ public class RouteTracker {
private final HashMap<InternalAddress, HashSet<InternalAddress>> lessorToLessees = new HashMap<>();
private final HashMap<InternalAddress, InternalAddress> lesseeToLessor = new HashMap<>();

private final HashMap<InternalAddress, HashSet<InternalAddress>> targetToSourceRoutes = new HashMap<>();
private final HashMap<InternalAddress, Pair<HashSet<InternalAddress>, HashSet<InternalAddress>>> targetToFlushedChannels = new HashMap<>();

// Add address mapping when scheduler forward a message
// A message sent from initiator to a PA on behalf of VA
public void activateRoute(Address initiator, Address pa){
Expand Down Expand Up @@ -48,6 +54,15 @@ public List<Address> getAllActiveRoutes (Address initiator){
return new ArrayList<>();
}

public List<Address> getAllActiveDownstreamRoutes (Address initiator){
if(routes.containsKey(initiator.toInternalAddress())){
return routes.get(initiator.toInternalAddress()).entrySet().stream()
.filter(pair-> pair.getValue() && (DataflowUtils.getFunctionId(pair.getKey().toAddress().type().getInternalType()) > DataflowUtils.getFunctionId(initiator.type().getInternalType())))
.map(kv->kv.getKey().toAddress()).collect(Collectors.toList());
}
return new ArrayList<>();
}


// get all PAs associated with a VA created by initiator
public Address[] getAllRoutes(Address initiator){
Expand Down Expand Up @@ -91,6 +106,16 @@ public Address getLessor(Address lessee){
return null;
}

public boolean ifLessorOf(Address lessor, Address lessee){
if(!lesseeToLessor.containsKey(lessee.toInternalAddress())) return false;
return lesseeToLessor.get(lessee.toInternalAddress()).equals(lessor.toInternalAddress());
}

public boolean ifLesseeOf(Address lessee, Address lessor){
if(!lessorToLessees.containsKey(lessor.toInternalAddress())) return false;
return lessorToLessees.get(lessor.toInternalAddress()).contains(lessee.toInternalAddress());
}

public boolean removeLessor(Address lessee){
return lesseeToLessor.remove(lessee.toInternalAddress()) != null;
}
Expand All @@ -104,4 +129,55 @@ public String getLesseeToLessorMap(){
return String.format("< %s >", lesseeToLessor.entrySet().stream()
.map(kv->kv.getKey().toAddress() + " -> " + kv.getValue().toAddress()).collect(Collectors.joining("|||")));
}

public void mergeTemporaryRoutingEntries(Address source, List<Address> targets){
for(Address target : targets){
targetToSourceRoutes.putIfAbsent(target.toInternalAddress(), new HashSet<>());
targetToSourceRoutes.get(target.toInternalAddress()).add(source.toInternalAddress());

}
}

public Map<InternalAddress, List<Address>> getTemporaryTargetToSourcesRoutes(){
return targetToSourceRoutes.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, kv->kv.getValue().stream().map(InternalAddress::toAddress).collect(Collectors.toList())));//.map(kv->new Map.Entry<>(kv.getKey().toAddress(), kv.getValue().stream().map(InternalAddress::toAddress).collect(Collectors.toList()))).coll;
}

public void clearTemporaryRoutingEntries(){
targetToSourceRoutes.clear();
}

public void onFlushDependencyReceive(Address target, List<Address> dependencies){
targetToFlushedChannels.putIfAbsent(target.toInternalAddress(), new Pair<>(new HashSet<>(), new HashSet<>()));
targetToFlushedChannels.get(target.toInternalAddress()).getKey().addAll(dependencies.stream().map(Address::toInternalAddress).collect(Collectors.toList()));
}

public void onFlushReceive(Address target, Address flushReceived){
targetToFlushedChannels.putIfAbsent(target.toInternalAddress(), new Pair<>(new HashSet<>(), new HashSet<>()));
targetToFlushedChannels.get(target.toInternalAddress()).getValue().add(flushReceived.toInternalAddress());
}

public boolean ifUpstreamFlushed(Address target){
Pair<HashSet<InternalAddress>, HashSet<InternalAddress>> flushes = targetToFlushedChannels.get(target.toInternalAddress());
if(flushes== null) {
System.out.println("Target address: " + target + " has no upstream flushes. tid: " + Thread.currentThread().getName());
return true;
}
return flushes.getKey().equals(flushes.getValue());
}

public void clearFlushDependencyReceived(Address target){
targetToFlushedChannels.remove(target.toInternalAddress());
}

public String getTargetToSourceRoutes() {
return targetToSourceRoutes.entrySet().stream().map(kv-> kv.getKey() + " -> " + Arrays.toString(kv.getValue().toArray())).collect(Collectors.joining("|||"));
}

public String getTargetToFlushedChannels() {
return targetToFlushedChannels.entrySet().stream()
.map(kv-> kv.getKey() + " -> (" + Arrays.toString(kv.getValue().getKey().toArray()) + " == " + Arrays.toString(kv.getValue().getValue().toArray()) + ")")
.collect(Collectors.joining(" ||| "));
}
}


Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package org.apache.flink.statefun.flink.core.functions;

import org.apache.flink.statefun.flink.core.functions.scheduler.LesseeSelector;
import org.apache.flink.statefun.flink.core.functions.scheduler.RandomLesseeSelector;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.InternalAddress;

import java.util.*;
import java.util.stream.Collectors;

// Class to hold all the information related to state aggregation
public class StateAggregationInfo {
// private int numUpstreams;
//private int numPartialStatesReceived;
private Address syncSource;
private Set<InternalAddress> lessees;
private ArrayList<Address> partitionedAddresses;
private LesseeSelector lesseeSelector;
private HashSet<InternalAddress> distinctPartialStateSources;
private Set<InternalAddress> expectedPartialStateSources;
private HashSet<InternalAddress> distinctCriticalMessages;
// TODO: Assigned from sync recv
public Set<InternalAddress> expectedCriticalMessageSources;
private Boolean autoblocking;
private Boolean pendingRequestServed;


public StateAggregationInfo(ReusableContext context) {
// this.numUpstreams = numUpstreams;
this.syncSource = null;
this.lessees = new HashSet<>();
//this.numPartialStatesReceived = 0;
this.partitionedAddresses = null;
this.lesseeSelector = new RandomLesseeSelector(context.getPartition());
this.distinctPartialStateSources = new HashSet<>();
this.expectedPartialStateSources = new HashSet<>();
this.distinctCriticalMessages = new HashSet<>();
this.expectedCriticalMessageSources = new HashSet<>();
this.autoblocking = null;
this.pendingRequestServed = true;
}

public void resetInfo() {
this.distinctPartialStateSources.clear();
this.distinctCriticalMessages.clear();
this.expectedCriticalMessageSources.clear();
this.expectedPartialStateSources.clear();
this.lessees.clear();
this.autoblocking = null;
this.syncSource = null;
}

public Address getSyncSource() {
return this.syncSource;
}

public void setSyncSource(Address syncSource) {
this.syncSource = syncSource;
}
// public int getNumUpstreams() {
// return this.numUpstreams;
// }

public void incrementNumPartialStatesReceived(InternalAddress address) {
this.distinctPartialStateSources.add(address);
//this.numPartialStatesReceived += 1;
}

public void incrementNumCriticalMessagesReceived(InternalAddress address) {
this.distinctCriticalMessages.add(address);
}

public boolean areAllPartialStatesReceived() {
//return (this.distinctPartialStateSources.size() == lesseeSelector.getBroadcastAddresses(lessor).size());
return (this.distinctPartialStateSources.size() == expectedPartialStateSources.size());
}

public boolean areAllCriticalMessagesReceived() {
return (this.distinctCriticalMessages.size() == expectedCriticalMessageSources.size());
}

public void setExpectedPartialStateSources(Set<InternalAddress> sources) {
expectedPartialStateSources = sources;
}

public Set<Address> getExpectedPartialStateSources() {
return expectedPartialStateSources.stream().map(x -> x.address).collect(Collectors.toSet());
}

public Set<Address> getExpectedCriticalMessage() {
return expectedCriticalMessageSources.stream().map(x -> x.address).collect(Collectors.toSet());
}

public void setExpectedCriticalMessageSources(Set<InternalAddress> sources) {
expectedCriticalMessageSources = sources;
}

public void addLessee(Address lessee) {
lessees.add(new InternalAddress(lessee, lessee.type().getInternalType()));
}

public List<Address> getLessees() {
return lessees.stream().map(ia -> ia.address).collect(Collectors.toList());
}

public boolean hasLessee(Address lessee) {
return lessees.contains(new InternalAddress(lessee, lessee.type().getInternalType()));
}

public void setAutoblocking(Boolean blocking) {
autoblocking = blocking;
}

public Boolean ifAutoblocking() {
return autoblocking;
}

public Boolean getPendingRequestServed() {
return pendingRequestServed;
}

public void setPendingRequestServed(Boolean requestServed) {
pendingRequestServed = requestServed;
}

// TODO: Use this function at all context forwards. Need to capture the context.forward() call
public void addPartition(Address partition) {
this.partitionedAddresses.add(partition);
}

public ArrayList<Address> getPartitionedAddresses() {
//return this.partitionedAddresses;
return lesseeSelector.getBroadcastAddresses(syncSource);
}
public HashSet<InternalAddress> getDistinctPartialStateSources() {
return distinctPartialStateSources;
}

public void setDistinctPartialStateSources(HashSet<InternalAddress> distinctPartialStateSources) {
this.distinctPartialStateSources = distinctPartialStateSources;
}

public HashSet<InternalAddress> getDistinctCriticalMessages() {
return distinctCriticalMessages;
}

public void setDistinctCriticalMessages(HashSet<InternalAddress> distinctCriticalMessages) {
this.distinctCriticalMessages = distinctCriticalMessages;
}

public Set<InternalAddress> getExpectedCriticalMessageSources() {
return expectedCriticalMessageSources;
}

@Override
public String toString() {
return String.format("StateAggregationInfo numPartialStatesReceived %d lessor %s partitionedAddresses %s hash %d", distinctPartialStateSources.size(),
(syncSource == null ? "null" : syncSource.toString()), (partitionedAddresses == null ? "null" : Arrays.toString(partitionedAddresses.toArray())), this.hashCode());
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.apache.flink.statefun.flink.core.functions;

import javafx.util.Pair;
import org.apache.flink.statefun.sdk.Address;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class SyncReplyState implements Serializable {
private HashMap<Pair<String, Address>, byte[]> stateMap;
private List<Address> targetList;

public SyncReplyState(HashMap<Pair<String, Address>, byte[]> map, List<Address> list){
stateMap = map;
targetList = list;
}

public HashMap<Pair<String, Address>, byte[]> getStateMap(){
return stateMap;
}

public List<Address> getTargetList(){
return targetList;
}
}
Loading