Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@

# Using Field to exclude the condition in serialization if it's a callable
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
activation_group: str = Field(default="")
"""Group identifier for forward dependencies.

When multiple edges point to the same target node, they are grouped by this field.
This allows distinguishing between different cycles or dependency patterns.

Example: In a graph containing a cycle like A->B->C->B, the two edges pointing to B (A->B and C->B)
can be in different activation groups to control how B is activated.
Defaults to the target node name if not specified.
"""
activation_condition: Literal["all", "any"] = "all"
"""Determines how forward dependencies within the same activation_group are evaluated.

- "all": All edges in this activation group must be satisfied before the target node can execute
- "any": Any single edge in this activation group being satisfied allows the target node to execute

This is used to handle complex dependency patterns in cyclic graphs where multiple
paths can lead to the same target node.
"""

@model_validator(mode="after")
def _validate_condition(self) -> "DiGraphEdge":
Expand All @@ -59,6 +78,11 @@
# For serialization purposes, we'll set the condition to None
# when storing as a pydantic model/dict
object.__setattr__(self, "condition", None)

# Set activation_group to target if not already set
if not self.activation_group:
self.activation_group = self.target

return self

def check_condition(self, message: BaseChatMessage) -> bool:
Expand Down Expand Up @@ -112,8 +136,7 @@
parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
for node in self.nodes.values():
for edge in node.edges:
if edge.target != node.name:
parents[edge.target].append(node.name)
parents[edge.target].append(node.name)
return parents

def get_start_nodes(self) -> Set[str]:
Expand Down Expand Up @@ -206,8 +229,79 @@
if has_condition and has_unconditioned:
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")

# Validate activation conditions across all edges in the graph
self._validate_activation_conditions()

self._has_cycles = self.has_cycles_with_exit()

def _validate_activation_conditions(self) -> None:
"""Validate that all edges pointing to the same target node have consistent activation_condition values.

Raises:
ValueError: If edges pointing to the same target have different activation_condition values
"""
target_activation_conditions: Dict[str, Dict[str, str]] = {} # target_node -> {activation_group -> condition}

for node in self.nodes.values():
for edge in node.edges:
target = edge.target # The target node this edge points to
activation_group = edge.activation_group

if target not in target_activation_conditions:
target_activation_conditions[target] = {}

if activation_group in target_activation_conditions[target]:
if target_activation_conditions[target][activation_group] != edge.activation_condition:
# Find the source node that has the conflicting condition
conflicting_source = self._find_edge_source_by_target_and_group(
target, activation_group, target_activation_conditions[target][activation_group]
)
raise ValueError(
f"Conflicting activation conditions for target '{target}' group '{activation_group}': "
f"'{target_activation_conditions[target][activation_group]}' (from node '{conflicting_source}') "
f"and '{edge.activation_condition}' (from node '{node.name}')"
)
else:
target_activation_conditions[target][activation_group] = edge.activation_condition

def _find_edge_source_by_target_and_group(
self, target: str, activation_group: str, activation_condition: str
) -> str:
"""Find the source node that has an edge pointing to the given target with the given activation_group and activation_condition."""
for node_name, node in self.nodes.items():
for edge in node.edges:
if (
edge.target == target
and edge.activation_group == activation_group
and edge.activation_condition == activation_condition
):
return node_name
return "unknown"

Check warning on line 279 in python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py#L279

Added line #L279 was not covered by tests

def get_remaining_map(self) -> Dict[str, Dict[str, int]]:
"""Get the remaining map that tracks how many edges point to each target node with each activation group.

Returns:
Dictionary mapping target nodes to their activation groups and remaining counts
"""

remaining_map: Dict[str, Dict[str, int]] = {}

for node in self.nodes.values():
for edge in node.edges:
target = edge.target
activation_group = edge.activation_group

if target not in remaining_map:
remaining_map[target] = {}

if activation_group not in remaining_map[target]:
remaining_map[target][activation_group] = 0

remaining_map[target][activation_group] += 1

return remaining_map


class GraphFlowManagerState(BaseGroupChatManagerState):
"""Tracks active execution state for DAG-based execution."""
Expand Down Expand Up @@ -254,18 +348,51 @@
self._parents = graph.get_parents()
# Lookup table for outgoing edges for each node.
self._edges: Dict[str, List[DiGraphEdge]] = {n: node.edges for n, node in graph.nodes.items()}
# Activation lookup table for each node.
self._activation: Dict[str, Literal["any", "all"]] = {n: node.activation for n, node in graph.nodes.items()}

# Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node
self._build_lookup_tables(graph)

# Track which activation groups were triggered for each node
self._triggered_activation_groups: Dict[str, Set[str]] = {}
# === Mutable states for the graph execution ===
# Count the number of remaining parents to activate each node.
self._remaining: Counter[str] = Counter({n: len(p) for n, p in self._parents.items()})
# Lookup table for nodes that have been enqueued through an any activation.
# This is used to prevent re-adding the same node multiple times.
self._enqueued_any: Dict[str, bool] = {n: False for n in graph.nodes}
self._remaining: Dict[str, Counter[str]] = {
target: Counter(groups) for target, groups in graph.get_remaining_map().items()
}
# cache for remaining
self._origin_remaining: Dict[str, Dict[str, int]] = {
target: Counter(groups) for target, groups in self._remaining.items()
}

# Ready queue for nodes that are ready to execute, starting with the start nodes.
self._ready: Deque[str] = deque([n for n in graph.get_start_nodes()])

def _build_lookup_tables(self, graph: DiGraph) -> None:
"""Build activation and enqueued_any lookup tables by collecting all edges and grouping by target node.

Args:
graph: The directed graph
"""
self._activation: Dict[str, Dict[str, Literal["any", "all"]]] = {}
self._enqueued_any: Dict[str, Dict[str, bool]] = {}

for node in graph.nodes.values():
for edge in node.edges:
target = edge.target
activation_group = edge.activation_group

# Build activation lookup
if target not in self._activation:
self._activation[target] = {}
if activation_group not in self._activation[target]:
self._activation[target][activation_group] = edge.activation_condition

# Build enqueued_any lookup
if target not in self._enqueued_any:
self._enqueued_any[target] = {}
if activation_group not in self._enqueued_any[target]:
self._enqueued_any[target][activation_group] = False

async def update_message_thread(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> None:
await super().update_message_thread(messages)

Expand All @@ -282,28 +409,65 @@
# Use the new check_condition method that handles both string and callable conditions
if not edge.check_condition(message):
continue
if self._activation[edge.target] == "all":
self._remaining[edge.target] -= 1
if self._remaining[edge.target] == 0:

target = edge.target
activation_group = edge.activation_group

if self._activation[target][activation_group] == "all":
self._remaining[target][activation_group] -= 1
if self._remaining[target][activation_group] == 0:
# If all parents are done, add to the ready queue.
self._ready.append(edge.target)
self._ready.append(target)
# Track which activation group was triggered
self._save_triggered_activation_group(target, activation_group)
else:
# If activation is any, add to the ready queue if not already enqueued.
if not self._enqueued_any[edge.target]:
self._ready.append(edge.target)
self._enqueued_any[edge.target] = True
if not self._enqueued_any[target][activation_group]:
self._ready.append(target)
self._enqueued_any[target][activation_group] = True

Check warning on line 427 in python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py#L425-L427

Added lines #L425 - L427 were not covered by tests
# Track which activation group was triggered
self._save_triggered_activation_group(target, activation_group)

Check warning on line 429 in python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py#L429

Added line #L429 was not covered by tests

def _save_triggered_activation_group(self, target: str, activation_group: str) -> None:
"""Save which activation group was triggered for a target node.

Args:
target: The target node that was triggered
activation_group: The activation group that caused the trigger
"""
if target not in self._triggered_activation_groups:
self._triggered_activation_groups[target] = set()
self._triggered_activation_groups[target].add(activation_group)

def _reset_triggered_activation_groups(self, speaker: str) -> None:
"""Reset the bookkeeping for the specific activation groups that were triggered for a speaker.

Args:
speaker: The speaker node to reset activation groups for
"""
if speaker not in self._triggered_activation_groups:
return

for activation_group in self._triggered_activation_groups[speaker]:
if self._activation[speaker][activation_group] == "any":
self._enqueued_any[speaker][activation_group] = False

Check warning on line 453 in python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py#L453

Added line #L453 was not covered by tests
else:
# Reset the remaining count for this activation group using the graph's original count
if speaker in self._remaining and activation_group in self._remaining[speaker]:
self._remaining[speaker][activation_group] = self._origin_remaining[speaker][activation_group]

# Clear the triggered activation groups for this speaker
self._triggered_activation_groups[speaker].clear()

async def select_speaker(self, thread: Sequence[BaseAgentEvent | BaseChatMessage]) -> List[str]:
# Drain the ready queue for the next set of speakers.
speakers: List[str] = []
while self._ready:
speaker = self._ready.popleft()
speakers.append(speaker)
# Reset the bookkeeping for the node that were selected.
if self._activation[speaker] == "any":
self._enqueued_any[speaker] = False
else:
self._remaining[speaker] = len(self._parents[speaker])

# Reset the bookkeeping for the specific activation groups that were triggered
self._reset_triggered_activation_groups(speaker)

# If there are no speakers, trigger the stop agent.
if not speakers:
Expand All @@ -319,7 +483,7 @@
state = {
"message_thread": [message.dump() for message in self._message_thread],
"current_turn": self._current_turn,
"remaining": dict(self._remaining),
"remaining": {target: dict(counter) for target, counter in self._remaining.items()},
"enqueued_any": dict(self._enqueued_any),
"ready": list(self._ready),
}
Expand All @@ -329,7 +493,7 @@
"""Restore execution state from saved data."""
self._message_thread = [self._message_factory.create(msg) for msg in state["message_thread"]]
self._current_turn = state["current_turn"]
self._remaining = Counter(state["remaining"])
self._remaining = {target: Counter(groups) for target, groups in state["remaining"].items()}
self._enqueued_any = state["enqueued_any"]
self._ready = deque(state["ready"])

Expand All @@ -339,8 +503,8 @@
self._message_thread.clear()
if self._termination_condition:
await self._termination_condition.reset()
self._remaining = Counter({n: len(p) for n, p in self._parents.items()})
self._enqueued_any = {n: False for n in self._graph.nodes}
self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
self._enqueued_any = {n: {g: False for g in self._enqueued_any[n]} for n in self._enqueued_any}
self._ready = deque([n for n in self._graph.get_start_nodes()])


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,23 @@ class DiGraphBuilder:
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
>>> # Add exit condition to break the loop
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())

Example — Loop with multiple paths to the same node: A → B → C → B:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_edge(agent_a, agent_b)
>>> builder.add_edge(agent_b, agent_c)
>>> builder.add_edge(agent_c, agent_b, activation_group="loop_back")

Example — Loop with multiple paths to the same node with any activation condition: A → B → (C1, C2) → B → E(exit):
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c1).add_node(agent_c2).add_node(agent_e)
>>> builder.add_edge(agent_a, agent_b)
>>> builder.add_edge(agent_b, agent_c1)
>>> builder.add_edge(agent_b, agent_c2)
>>> builder.add_edge(agent_b, agent_e, condition="exit")
>>> builder.add_edge(agent_c1, agent_b, activation_group="loop_back_group", activation_condition="any")
>>> builder.add_edge(agent_c2, agent_b, activation_group="loop_back_group", activation_condition="any")
"""

def __init__(self) -> None:
Expand All @@ -97,6 +114,8 @@ def add_edge(
source: Union[str, ChatAgent],
target: Union[str, ChatAgent],
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
activation_group: Optional[str] = None,
activation_condition: Optional[Literal["all", "any"]] = None,
) -> "DiGraphBuilder":
"""Add a directed edge from source to target, optionally with a condition.

Expand All @@ -120,8 +139,18 @@ def add_edge(
raise ValueError(f"Source node '{source_name}' must be added before adding an edge.")
if target_name not in self.nodes:
raise ValueError(f"Target node '{target_name}' must be added before adding an edge.")

self.nodes[source_name].edges.append(DiGraphEdge(target=target_name, condition=condition))
if activation_group is None:
activation_group = target_name
if activation_condition is None:
activation_condition = "all"
self.nodes[source_name].edges.append(
DiGraphEdge(
target=target_name,
condition=condition,
activation_group=activation_group,
activation_condition=activation_condition,
)
)
return self

def add_conditional_edges(
Expand Down
Loading
Loading