Skip to content

Graph

Graph

Bases: BaseModel

Create and execute a graph defined by a list of edges.

Set the required State and Shared classes via the Generic Typing Parameters. Because of variance, it is possible to use nodes whose State and Shared classes are more general than (ancestors of) the Generic Typing Parameters.

The edges are defined as a list of tuples, where the first element is the source node (or a list of source nodes) and the second element determines the next node.

Parameters:

Name Type Description Default
edges

A list of edges of compatible nodes that build the graph

required
instant_edges

A list of edges of compatible nodes that run parallel to their source node

required
Source code in src/edgygraph/graph.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
class Graph[T: State = State, S: Shared = Shared](BaseModel):
    """
    Create and execute a graph defined by a list of edges.

    Set the required State and Shared classes via the Generic Typing Parameters.
    Because of variance its possible to use nodes, that use more general State and Shared classes (ancestors) as the Generic Typing Parameters. 

    The edges are defined as a list of tuples, where the first element is the source node and the second element reveals the next node.

    Args:
        edges: A list of edges of compatible nodes that build the graph
        instant_edge: A list of edges of compatible nodes that run parallel to there source node
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    edges: list[Edge[T, S]] = Field(default_factory=list[Edge[T, S]])
    instant_edges: list[Edge[T, S]] = Field(default_factory=list[Edge[T, S]])

    _edge_index: dict[Node[T, S] | Type[START], list[NextType[T, S]]] = defaultdict(list[NextType[T, S]])
    _instant_edge_index: dict[Node[T, S] | Type[START], list[NextType[T, S]]] = defaultdict(list[NextType[T, S]])


    def model_post_init(self, _) -> None:
        """
        Index the edges by source node
        """
        self._edge_index = self.index_edges(self.edges)
        self._instant_edge_index = self.index_edges(self.instant_edges)


    def index_edges(self, edges: list[Edge[T, S]]) -> dict[Node[T, S] | Type[START], list[NextType[T, S]]]:
        """
        Index the edges by source node

        Args:
           edges: The edges to index

        Returns:
            A mapping from source node (or START) to the next objects of the edge

        """

        edges_index: dict[Node[T, S] | Type[START], list[NextType[T, S]]] = defaultdict(list[NextType[T, S]])

        for edge in edges:
            sources = edge[0]
            if isinstance(sources, list):
                for source in sources:
                    edges_index[source].append(edge[1])
            else:
                edges_index[sources].append(edge[1])

        return edges_index




    async def __call__(self, state: T, shared: S) -> Tuple[T, S]:
        """
        Execute the graph based on the edges

        Args:
            state: State of the first generic type of the graph or a subtype
            shared: Shared of the second generic type of the graph or a subtype

        Returns:
            New State instance and the same Shared instance
        """

        current_nodes: list[Node[T, S]] | list[Node[T, S] | Type[START]] = [START]

        while True:

            next_nodes: list[Node[T, S]] = await self.get_next_nodes(state, shared, current_nodes, self._edge_index)

            if not next_nodes:
                break # END


            current_instant_nodes: list[Node[T, S]] = next_nodes.copy()
            while True:

                current_instant_nodes = await self.get_next_nodes(state, shared, current_instant_nodes, self._instant_edge_index)

                logger.debug("CURRENT INSTANT NODES: %s", current_instant_nodes)

                if not current_instant_nodes:
                    break

                next_nodes.extend(current_instant_nodes)

            logger.debug("NEXT NODES: %s", next_nodes)

            parallel_tasks: list[Callable[[T, S], Coroutine[None, None, None]]] = []


            # Extract the run function of the nodes
            for next_node in next_nodes:

                parallel_tasks.append(next_node.run)


            # Run parallel
            result_states: list[T] = []

            async with asyncio.TaskGroup() as tg:
                for task in parallel_tasks:

                    state_copy: T = state.model_copy(deep=True)
                    result_states.append(state_copy)

                    tg.create_task(task(state_copy, shared))

            state = self.merge_states(state, result_states)

            current_nodes = next_nodes


        return state, shared

    async def get_next_nodes(self, state: T, shared: S, current_nodes: list[Node[T, S]] | list[Node[T, S] | Type[START]], edge_index: dict[Node[T, S] | Type[START], list[NextType[T, S]]]) -> list[Node[T, S]]:
        """
        Args:
            state: The current state
            shared: The shared state
            current_nodes: The current nodes

        Returns:
           The list of the next nodes to run based on the current nodes and edges.
           If an edge is a callable, it will be called with the current state and shared state.
        """


        next_types: list[NextType[T, S]] = []

        for current_node in current_nodes:

            # Find the edge corresponding to the current node
            next_types.extend(edge_index[current_node])


        next_nodes: list[Node[T, S]] = []
        for next in next_types:

            next = next

            if next is END:
                continue

            if isinstance(next, Callable):
                res = next(state, shared) #type:ignore (its not an END!)
                if inspect.isawaitable(res):
                    res = await res # for awaitables

                if isinstance(res, Node):
                    next_nodes.append(res)

            else:
                next_nodes.append(next)

        return next_nodes


    def merge_states(self, current_state: T, result_states: list[T]) -> T:
        """
        Merges the result states into the current state.
        First the changes are calculated for each result state.
        Then the changes are checked for conflicts.
        If there are conflicts, a ChangeConflictException is raised.
        The changes are applied in the order of the result states list.

        Args:
            current_state: The current state
            result_states: The result states

        Returns:
            The new merged State instance.

        Raises:
            ChangeConflictException: If there are conflicts in the changes.
        """

        result_dicts = [state.model_dump() for state in result_states]
        current_dict = current_state.model_dump()
        state_class = type(current_state)

        changes_list: list[dict[str, Change]] = []

        for result_dict in result_dicts:

            changes_list.append(Diff.recursive_diff(current_dict, result_dict))

        logger.debug(f"CHANGES: %s", changes_list)

        conflicts = Diff.find_conflicts(changes_list)

        if conflicts:
            raise ChangeConflictException(f"Conflicts detected after parallel execution: {conflicts}")

        for changes in changes_list:
            Diff.apply_changes(current_dict, changes)

        state: T = state_class.model_validate(current_dict)

        logger.debug("NEW STATE: %s", state)

        return state

model_post_init(_)

Index the edges by source node

Source code in src/edgygraph/graph.py
44
45
46
47
48
49
def model_post_init(self, _) -> None:
    """
    Build the source-node lookup tables for the regular and the instant edges.

    The regular edges are indexed first, then the instant edges; both go through
    ``index_edges``.
    """
    regular_index = self.index_edges(self.edges)
    self._edge_index = regular_index
    self._instant_edge_index = self.index_edges(self.instant_edges)

index_edges(edges)

Index the edges by source node

Parameters:

Name Type Description Default
edges list[Edge[T, S]]

The edges to index

required

Returns:

Type Description
dict[Node[T, S] | Type[START], list[NextType[T, S]]]

A mapping from source node (or START) to the next objects of the edge

Source code in src/edgygraph/graph.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def index_edges(self, edges: list[Edge[T, S]]) -> dict[Node[T, S] | Type[START], list[NextType[T, S]]]:
    """
    Group edge targets by their source node.

    An edge whose source is a list registers its target under each node of that list.

    Args:
       edges: The edges to index

    Returns:
        A mapping from source node (or START) to the next objects of the edge, in the
        order the edges were given. The mapping is a defaultdict of lists.
    """

    index: dict[Node[T, S] | Type[START], list[NextType[T, S]]] = defaultdict(list[NextType[T, S]])

    for edge in edges:
        head = edge[0]
        sources = head if isinstance(head, list) else (head,)
        for source in sources:
            index[source].append(edge[1])

    return index

__call__(state, shared) async

Execute the graph based on the edges

Parameters:

Name Type Description Default
state T

State of the first generic type of the graph or a subtype

required
shared S

Shared of the second generic type of the graph or a subtype

required

Returns:

Type Description
Tuple[T, S]

New State instance and the same Shared instance

Source code in src/edgygraph/graph.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
async def __call__(self, state: T, shared: S) -> Tuple[T, S]:
    """
    Execute the graph based on the edges.

    Starting from START, each step:
    1. collects the targets of the regular edges of the nodes run in the previous step;
    2. adds every node reachable from those via instant edges;
    3. runs all of them concurrently, each on its own deep copy of the state;
    4. merges the resulting states.

    Execution stops when a step yields no node to run.
    Each node is expanded over instant edges at most once per step, so cycles of
    instant edges terminate.

    Args:
        state: State of the first generic type of the graph or a subtype
        shared: Shared of the second generic type of the graph or a subtype

    Returns:
        New State instance and the same Shared instance

    Raises:
        ChangeConflictException: If parallel nodes changed overlapping parts of the state.
    """

    current_nodes: list[Node[T, S]] | list[Node[T, S] | Type[START]] = [START]

    while True:

        next_nodes: list[Node[T, S]] = await self.get_next_nodes(state, shared, current_nodes, self._edge_index)

        if not next_nodes:
            break # END

        # Follow instant edges transitively. `seen` guards against cycles, which previously
        # made this loop spin forever while growing `next_nodes`.
        seen: set[Node[T, S]] = set(next_nodes)
        frontier: list[Node[T, S]] = next_nodes.copy()
        while frontier:

            reached = await self.get_next_nodes(state, shared, frontier, self._instant_edge_index)

            logger.debug("CURRENT INSTANT NODES: %s", reached)

            frontier = []
            for node in reached:
                if node not in seen:
                    seen.add(node)
                    frontier.append(node)

            next_nodes.extend(frontier)

        logger.debug("NEXT NODES: %s", next_nodes)

        # Run parallel: every node works on its own deep copy so results can be diffed and merged.
        # Exceptions raised by nodes propagate out of the TaskGroup.
        result_states: list[T] = []

        async with asyncio.TaskGroup() as tg:
            for next_node in next_nodes:

                state_copy: T = state.model_copy(deep=True)
                result_states.append(state_copy)

                tg.create_task(next_node.run(state_copy, shared))

        state = self.merge_states(state, result_states)

        current_nodes = next_nodes


    return state, shared

get_next_nodes(state, shared, current_nodes, edge_index) async

Parameters:

Name Type Description Default
state T

The current state

required
shared S

The shared state

required
current_nodes list[Node[T, S]] | list[Node[T, S] | Type[START]]

The current nodes

required

Returns:

Type Description
list[Node[T, S]]

The list of the next nodes to run based on the current nodes and edges.

list[Node[T, S]]

If an edge is a callable, it will be called with the current state and shared state.

Source code in src/edgygraph/graph.py
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
async def get_next_nodes(self, state: T, shared: S, current_nodes: list[Node[T, S]] | list[Node[T, S] | Type[START]], edge_index: dict[Node[T, S] | Type[START], list[NextType[T, S]]]) -> list[Node[T, S]]:
    """
    Resolve the nodes that follow `current_nodes` according to `edge_index`.

    Args:
        state: The current state
        shared: The shared state
        current_nodes: The current nodes
        edge_index: Mapping from source node (or START) to edge targets, as built by `index_edges`

    Returns:
       The list of the next nodes to run based on the current nodes and edges.
       END targets are skipped. A callable target is called with the current state and
       shared state (and awaited if needed); its result is used only if it is a Node.
    """

    targets: list[NextType[T, S]] = []

    for current_node in current_nodes:

        # .get() instead of [] so a node without outgoing edges is not inserted into the defaultdict index
        targets.extend(edge_index.get(current_node, ()))


    next_nodes: list[Node[T, S]] = []
    for target in targets:

        if target is END:
            continue

        if isinstance(target, Node):
            next_nodes.append(target)

        elif callable(target):
            res = target(state, shared) #type:ignore (its not an END!)
            if inspect.isawaitable(res):
                res = await res # for awaitables

            if isinstance(res, Node):
                next_nodes.append(res)

        else:
            next_nodes.append(target)

    return next_nodes

merge_states(current_state, result_states)

Merges the result states into the current state. First the changes are calculated for each result state. Then the changes are checked for conflicts. If there are conflicts, a ChangeConflictException is raised. The changes are applied in the order of the result states list.

Parameters:

Name Type Description Default
current_state T

The current state

required
result_states list[T]

The result states

required

Returns:

Type Description
T

The new merged State instance.

Raises:

Type Description
ChangeConflictException

If there are conflicts in the changes.

Source code in src/edgygraph/graph.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def merge_states(self, current_state: T, result_states: list[T]) -> T:
    """
    Merges the result states into the current state.
    First the changes are calculated for each result state.
    Then the changes are checked for conflicts.
    If there are conflicts, a ChangeConflictException is raised.
    The changes are applied in the order of the result states list.

    Args:
        current_state: The current state
        result_states: The result states

    Returns:
        The new merged State instance.

    Raises:
        ChangeConflictException: If there are conflicts in the changes.
    """

    result_dicts = [state.model_dump() for state in result_states]
    current_dict = current_state.model_dump()
    state_class = type(current_state)

    # One change set per result state, all diffed against the same (unmodified) current state.
    changes_list: list[dict[str, Change]] = [
        Diff.recursive_diff(current_dict, result_dict) for result_dict in result_dicts
    ]

    # Lazy %-formatting; the previous f-prefix had no placeholders.
    logger.debug("CHANGES: %s", changes_list)

    conflicts = Diff.find_conflicts(changes_list)

    if conflicts:
        raise ChangeConflictException(f"Conflicts detected after parallel execution: {conflicts}")

    for changes in changes_list:
        Diff.apply_changes(current_dict, changes)

    state: T = state_class.model_validate(current_dict)

    logger.debug("NEW STATE: %s", state)

    return state

ChangeTypes

Bases: StrEnum

Enum for the types of changes that can be made to a State.

Source code in src/edgygraph/graph.py
231
232
233
234
235
236
237
238
class ChangeTypes(StrEnum):
    """
    Enum for the types of changes that can be made to a State.
    """

    ADDED = auto()
    REMOVED = auto()
    UPDATED = auto()

Change

Bases: RichReprMixin, BaseModel

Represents a change made to a State.

Source code in src/edgygraph/graph.py
240
241
242
243
244
245
246
247
class Change(RichReprMixin, BaseModel):
    """
    Represents a change made to a State.

    ``Diff.recursive_diff`` creates one per changed dotted path, and
    ``Diff.apply_changes`` applies it.
    """

    # Kind of change: ADDED, REMOVED or UPDATED.
    type: ChangeTypes
    # Value before the change; recursive_diff sets it to None for ADDED changes.
    old: Any
    # Value after the change; recursive_diff sets it to None for REMOVED changes.
    new: Any

Diff

Utility class for computing differences between states.

Source code in src/edgygraph/graph.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
class Diff:
    """
    Utility class for computing differences between states.

    Changes are keyed by paths: the keys of nested dictionaries joined with dots,
    e.g. ``"a.b"``.
    """


    @classmethod
    def find_conflicts(cls, changes: list[dict[str, Change]]) -> dict[str, list[Change]]:
        """
        Finds conflicts in a list of changes.

        Two change sets conflict when they change the same path, or when one changes
        a path nested below a path changed by the other (e.g. one replaces ``"a"``
        while the other updates ``"a.b"``). Applying both would either depend on their
        order or fail.

        Args:
           changes: A list of dictionaries representing changes to a state, one per parallel branch.

        Returns:
            A mapping from each conflicting path to the changes made at that path;
            empty if there are no conflicts.
        """

        if len(changes) <= 1:
            return {}

        conflicts: dict[str, list[Change]] = {}

        def record(path: str, change: Change) -> None:
            # Add `change` under `path` at most once (compared by identity).
            bucket = conflicts.setdefault(path, [])
            if not any(existing is change for existing in bucket):
                bucket.append(change)

        # Indices of the change sets that touch each path, in first-seen order.
        owners: dict[str, list[int]] = {}
        for index, branch in enumerate(changes):
            for path in branch:
                owners.setdefault(path, []).append(index)

        # The same path changed by more than one change set.
        for path, indices in owners.items():
            if len(indices) > 1:
                for index in indices:
                    record(path, changes[index][path])

        # A path changed by one change set lies below a path changed by another one.
        # (recursive_diff never yields nested paths within a single change set.)
        for index, branch in enumerate(changes):
            for path in branch:
                parts = path.split(".")
                for depth in range(1, len(parts)):
                    ancestor = ".".join(parts[:depth])
                    for other in owners.get(ancestor, ()):
                        if other != index:
                            record(ancestor, changes[other][ancestor])
                            record(path, branch[path])

        return conflicts


    @classmethod
    def recursive_diff(cls, old: Any, new: Any, path: str = "") -> dict[str, Change]:
        """
        Recursively computes the differences between two dictionaries.

        Nested dictionaries present on both sides are compared key by key. Any other
        pair of differing values, including lists (compared as a whole), produces one
        UPDATED change at that path.

        Args:
            old: Part of the old dictionary.
            new: Part of the new dictionary.
            path: The current path of the parts in the full dictionary, separated with dots.

        Returns:
            A mapping of the path to the changes directly on that level.
        """

        changes: dict[str, Change] = {}

        if isinstance(old, dict) and isinstance(new, dict):
            all_keys: set[Any] = set(old.keys()) | set(new.keys()) #type: ignore

            for key in all_keys:
                # Paths are always str, also for non-string dict keys (e.g. int keys become "1").
                current_path: str = f"{path}.{key}" if path else str(key)

                if key not in new:
                    changes[current_path] = Change(type=ChangeTypes.REMOVED, old=old[key], new=None)
                elif key not in old:
                    changes[current_path] = Change(type=ChangeTypes.ADDED, old=None, new=new[key])
                else:
                    changes.update(cls.recursive_diff(old[key], new[key], current_path))

        elif old != new:
            changes[path] = Change(type=ChangeTypes.UPDATED, old=old, new=new)

        return changes


    @classmethod
    def apply_changes(cls, target: dict[str, Any], changes: dict[str, Change]) -> None:
        """
        Applies a set of changes to the target dictionary in place.

        Each path segment is matched against the existing keys first literally and then
        by their string form. A change recorded for a non-string key (``1`` appears as
        ``"1"`` in the path) therefore updates or removes the existing entry instead of
        adding a separate string key.

        Args:
            target: The dictionary to apply the changes to.
            changes: A mapping of paths, separated by dots, to changes. The changes are applied in the dictionary on that level.
        """

        def resolve(container: dict[Any, Any], part: str) -> Any:
            # Return the existing key that `part` refers to, or `part` itself if none matches.
            if part in container:
                return part
            return next((key for key in container if str(key) == part), part)

        for path, change in changes.items():
            parts = path.split(".")
            cursor = target

            # Navigate down the dictionary
            for part in parts[:-1]:
                key = resolve(cursor, part)
                if key not in cursor:
                    cursor[key] = {} # Intermediate level is missing; create it
                cursor = cursor[key]

            last_key = resolve(cursor, parts[-1])

            if change.type == ChangeTypes.REMOVED:
                cursor.pop(last_key, None)
            else:
                # UPDATED or ADDED
                cursor[last_key] = change.new

find_conflicts(changes) classmethod

Finds conflicts in a list of changes.

Parameters:

Name Type Description Default
changes list[dict[str, Change]]

A list of dictionaries representing changes to a state.

required
Source code in src/edgygraph/graph.py
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@classmethod
def find_conflicts(cls, changes: list[dict[str, Change]]) -> dict[str, list[Change]]:
    """
    Finds conflicts in a list of changes.

    Two change sets conflict when they change the same path, or when one changes
    a path nested below a path changed by the other (e.g. one replaces ``"a"``
    while the other updates ``"a.b"``). Applying both would either depend on their
    order or fail.

    Args:
       changes: A list of dictionaries representing changes to a state, one per parallel branch.

    Returns:
        A mapping from each conflicting path to the changes made at that path;
        empty if there are no conflicts.
    """

    if len(changes) <= 1:
        return {}

    conflicts: dict[str, list[Change]] = {}

    def record(path: str, change: Change) -> None:
        # Add `change` under `path` at most once (compared by identity).
        bucket = conflicts.setdefault(path, [])
        if not any(existing is change for existing in bucket):
            bucket.append(change)

    # Indices of the change sets that touch each path, in first-seen order.
    owners: dict[str, list[int]] = {}
    for index, branch in enumerate(changes):
        for path in branch:
            owners.setdefault(path, []).append(index)

    # The same path changed by more than one change set.
    for path, indices in owners.items():
        if len(indices) > 1:
            for index in indices:
                record(path, changes[index][path])

    # A path changed by one change set lies below a path changed by another one.
    for index, branch in enumerate(changes):
        for path in branch:
            parts = path.split(".")
            for depth in range(1, len(parts)):
                ancestor = ".".join(parts[:depth])
                for other in owners.get(ancestor, ()):
                    if other != index:
                        record(ancestor, changes[other][ancestor])
                        record(path, branch[path])

    return conflicts

recursive_diff(old, new, path='') classmethod

Recursively computes the differences between two dictionaries.

Parameters:

Name Type Description Default
old Any

Part of the old dictionary.

required
new Any

Part of the new dictionary.

required
path str

The current path of the parts in the full dictionary, separated with dots.

''

Returns:

Type Description
dict[str, Change]

A mapping of the path to the changes directly on that level.

Source code in src/edgygraph/graph.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
@classmethod
def recursive_diff(cls, old: Any, new: Any, path: str = "") -> dict[str, Change]:
    """
    Recursively computes the differences between two dictionaries.

    Nested dictionaries present on both sides are compared key by key. Any other
    pair of differing values, including lists (compared as a whole), produces one
    UPDATED change at that path.

    Args:
        old: Part of the old dictionary.
        new: Part of the new dictionary.
        path: The current path of the parts in the full dictionary, separated with dots.

    Returns:
        A mapping of the path to the changes directly on that level.
    """

    changes: dict[str, Change] = {}

    if isinstance(old, dict) and isinstance(new, dict):
        all_keys: set[Any] = set(old.keys()) | set(new.keys()) #type: ignore

        for key in all_keys:
            # Paths are always str, also for non-string dict keys (e.g. int keys become "1").
            current_path: str = f"{path}.{key}" if path else str(key)

            if key not in new:
                changes[current_path] = Change(type=ChangeTypes.REMOVED, old=old[key], new=None)
            elif key not in old:
                changes[current_path] = Change(type=ChangeTypes.ADDED, old=None, new=new[key])
            else:
                changes.update(cls.recursive_diff(old[key], new[key], current_path))

    elif old != new:
        changes[path] = Change(type=ChangeTypes.UPDATED, old=old, new=new)

    return changes

apply_changes(target, changes) classmethod

Applies a set of changes to the target dictionary.

Parameters:

Name Type Description Default
target dict[str, Any]

The dictionary to apply the changes to.

required
changes dict[str, Change]

A mapping of paths, separated by dots, to changes. The changes are applied in the dictionary on that level.

required
Source code in src/edgygraph/graph.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@classmethod
def apply_changes(cls, target: dict[str, Any], changes: dict[str, Change]) -> None:
    """
    Applies a set of changes to the target dictionary in place.

    Each path segment is matched against the existing keys first literally and then
    by their string form. A change recorded for a non-string key (``1`` appears as
    ``"1"`` in the path) therefore updates or removes the existing entry instead of
    adding a separate string key.

    Args:
        target: The dictionary to apply the changes to.
        changes: A mapping of paths, separated by dots, to changes. The changes are applied in the dictionary on that level.
    """

    def resolve(container: dict[Any, Any], part: str) -> Any:
        # Return the existing key that `part` refers to, or `part` itself if none matches.
        if part in container:
            return part
        return next((key for key in container if str(key) == part), part)

    for path, change in changes.items():
        parts = path.split(".")
        cursor = target

        # Navigate down the dictionary
        for part in parts[:-1]:
            key = resolve(cursor, part)
            if key not in cursor:
                cursor[key] = {} # Intermediate level is missing; create it
            cursor = cursor[key]

        last_key = resolve(cursor, parts[-1])

        if change.type == ChangeTypes.REMOVED:
            cursor.pop(last_key, None)
        else:
            # UPDATED or ADDED
            cursor[last_key] = change.new

ChangeConflictException

Bases: Exception

Exception raised when a conflict between changes to a state is detected.

Source code in src/edgygraph/graph.py
349
350
351
352
353
class ChangeConflictException(Exception):
    """
    Raised when the changes that parallel branches made to a state cannot be merged
    because they overlap.
    """