From fbad264cd92e01bbebd68d49ac1db4674ec22157 Mon Sep 17 00:00:00 2001 From: Riff Date: Tue, 9 Jan 2024 14:41:31 -0800 Subject: [PATCH] [sai-gen] Make SAI generator understand P4 IR to fetch counter references / callers. (#499) ## Problem Currently, the SAI attributes are generated from table keys, action parameters and counters. Table keys and action parameters are naturally connected to their tables, directly or indirectly, in the P4 runtime json. However, counters are not, so we don't have enough information to find which SAI API the attribute should be added to. To solve this, we have already added an annotation in SaiCounter to specify the action names, so we can correlate each counter with its tables. However, this is tedious to do if we have a lot of counters to add. ## What are we doing in this change This change parses the P4 IR json file to find which actions each counter is called from, so we can associate each counter with its actions and tables automatically. For example, say we add a new counter in the metering action to calculate total bytes: ![image](https://github.com/sonic-net/DASH/assets/1533278/30762751-c5ce-4550-8151-d0d87c73da75) And without specifying the action names, like the other counters, the SAI attribute will be generated in the right place, because we now understand the program. ![image](https://github.com/sonic-net/DASH/assets/1533278/dc29db67-5bcb-4dec-9142-aaa1fbbc8a40) We choose the P4 IR as the input instead of the bmv2 json because this approach is backend-independent and can apply to other backends too, such as dpdk and p4tc, just like the p4 runtime json. So, if we decide to change our BM from BMv2 to another implementation, this program will still work. This change will help us add more counters for future scenarios, such as HA. 
--- dash-pipeline/SAI/Makefile | 1 + dash-pipeline/SAI/sai_api_gen.py | 192 ++++++++++++++---- .../dockerfiles/Dockerfile.saithrift-bldr | 2 +- 3 files changed, 158 insertions(+), 37 deletions(-) diff --git a/dash-pipeline/SAI/Makefile b/dash-pipeline/SAI/Makefile index 90e70e664..0403e8101 100644 --- a/dash-pipeline/SAI/Makefile +++ b/dash-pipeline/SAI/Makefile @@ -2,6 +2,7 @@ all: copysrc ./sai_api_gen.py \ /bmv2/dash_pipeline.bmv2/dash_pipeline_p4rt.json \ + --ir /bmv2/dash_pipeline.bmv2/dash_pipeline_ir.json \ --ignore-tables=appliance,eni_meter,slb_decap \ dash diff --git a/dash-pipeline/SAI/sai_api_gen.py b/dash-pipeline/SAI/sai_api_gen.py index 71e86bcfe..b7198a95c 100755 --- a/dash-pipeline/SAI/sai_api_gen.py +++ b/dash-pipeline/SAI/sai_api_gen.py @@ -6,7 +6,9 @@ import argparse import copy from jinja2 import Template, Environment, FileSystemLoader - from typing import (Type, Any, Dict, List, Optional) + from typing import (Type, Any, Dict, List, Optional, Callable) + import jsonpath_ng.ext as jsonpath_ext + import jsonpath_ng as jsonpath except ImportError as ie: print("Import failed for " + ie.name) exit(1) @@ -38,6 +40,107 @@ SAI_COUNTER_TAG: str = 'SaiCounter' SAI_TABLE_TAG: str = 'SaiTable' +# +# P4 IR parser and analyzer: +# +class P4IRTree: + @staticmethod + def from_file(path: str) -> 'P4IRTree': + with open(path, 'r') as f: + return P4IRTree(json.load(f)) + + def __init__(self, program: Dict[str, Any]) -> None: + self.program = program + + def walk(self, path: str, on_match: Callable[[Any, Any], None]) -> None: + jsonpath_exp = jsonpath_ext.parse(path) + for match in jsonpath_exp.find(self.program): + on_match(match) + + +class P4IRVarInfo: + @staticmethod + def from_ir(ir_def_node: Any) -> 'P4IRVarInfo': + return P4IRVarInfo( + ir_def_node["Node_ID"], + ir_def_node["name"], + ir_def_node["Source_Info"]["source_fragment"], + ir_def_node["type"]["path"]["name"]) + + def __init__(self, ir_id: int, ir_name: str, code_name: str, type_name: str) -> 
None: + self.ir_id = ir_id + self.ir_name = ir_name + self.code_name = code_name + self.type_name = type_name + + def __str__(self) -> str: + return f"ID = {self.ir_id}, Name = {self.ir_name}, VarName = {self.code_name}, Type = {self.type_name}" + + +class P4IRVarRefInfo: + @staticmethod + def from_ir(ir_ref_node: Any, ir_var_info: P4IRVarInfo) -> 'P4IRVarRefInfo': + return P4IRVarRefInfo( + ir_var_info, + ir_ref_node["Node_ID"], + ir_ref_node["Node_Type"], + ir_ref_node["name"]) + + def __init__(self, var: P4IRVarInfo, caller_id: int, caller_type: str, caller: str) -> None: + self.var = var + self.caller_id = caller_id + self.caller_type = caller_type + self.caller = caller + + def __str__(self) -> str: + return f"VarName = {self.var.code_name}, CallerID = {self.caller_id}, CallerType = {self.caller_type}, Caller = {self.caller}" + + +class P4VarRefGraph: + def __init__(self, ir: P4IRTree) -> None: + self.ir = ir + self.counters: Dict[str, P4IRVarInfo] = {} + self.var_refs: Dict[str, List[P4IRVarRefInfo]] = {} + self.__build_graph() + + def __build_graph(self) -> None: + self.__build_counter_list() + self.__build_counter_caller_mapping() + pass + + def __build_counter_list(self) -> None: + def on_counter_definition(match: jsonpath.DatumInContext) -> None: + ir_value = P4IRVarInfo.from_ir(match.value) + self.counters[ir_value.ir_name] = ir_value + print(f"Counter definition found: {ir_value}") + + self.ir.walk('$..*[?Node_Type = "Declaration_Instance" & type.Node_Type = "Type_Name" & type.path.name = "counter"]', on_counter_definition) + + def __build_counter_caller_mapping(self) -> None: + # Build the mapping from counter name to its caller. + def on_counter_invocation(match: jsonpath.DatumInContext) -> None: + var_ir_name: str = match.value["expr"]["path"]["name"] + if var_ir_name not in self.counters: + return + + # Walk through the parent nodes to find the closest action or control block. 
+ cur_node = match + while cur_node.context is not None: + cur_node = cur_node.context + if "Node_Type" not in cur_node.value: + continue + + if cur_node.value["Node_Type"] in ["P4Action", "P4Control"]: + var = self.counters[var_ir_name] + var_ref = P4IRVarRefInfo.from_ir(cur_node.value, var) + self.var_refs.setdefault(var.code_name, []).append(var_ref) + print(f"Counter reference found: {var_ref}") + break + + # Get all nodes with Node_Type = "Member" and member = "count". These are the nodes that represent the counter calls. + self.ir.walk('$..*[?Node_Type = "Member" & member = "count"]', on_counter_invocation) + + # # SAI parser decorators: # @@ -214,6 +317,7 @@ def __get_range_list_match_key_sai_type(key_size: int) -> str: class SAIObject: def __init__(self): # Properties from P4Runtime preamble + self.raw_name: str = '' self.name: str = '' self.id: int = 0 self.alias: str = '' @@ -233,18 +337,22 @@ def parse_basic_info_if_exists(self, p4rt_object: Dict[str, Any]) -> None: ''' if PREAMBLE_TAG in p4rt_object: preamble = p4rt_object[PREAMBLE_TAG] - self.id = preamble['id'] - self.name = preamble['name'] - self.alias = preamble['alias'] + self.id = int(preamble['id']) + self.name = str(preamble['name']) + self.alias = str(preamble['alias']) else: - self.id = p4rt_object['id'] if 'id' in p4rt_object else self.id - self.name = p4rt_object['name'] if 'name' in p4rt_object else self.name + self.id = int(p4rt_object['id']) if 'id' in p4rt_object else self.id + self.name = str(p4rt_object['name']) if 'name' in p4rt_object else self.name # We only care about the last piece of the name, which is the actual object name. if '.' in self.name: name_parts = self.name.split('.') self.name = name_parts[-1] + # We save the raw name here, because "name" can be overridden by annotation for API generation purposes, and the raw name will help us + # to find the correlated P4 information from either Runtime or IR. 
+ self.raw_name = self.name + return def _parse_sai_common_annotation(self, p4rt_anno: Dict[str, Any]) -> None: @@ -256,10 +364,10 @@ def _parse_sai_common_annotation(self, p4rt_anno: Dict[str, Any]) -> None: { "key": "type", "value": { "stringValue": "sai_ip_addr_family_t" } } ''' if p4rt_anno['key'] == 'name': - self.name = p4rt_anno['value']['stringValue'] + self.name = str(p4rt_anno['value']['stringValue']) return True elif p4rt_anno['key'] == 'order': - self.order = p4rt_anno['value']['int64Value'] + self.order = str(p4rt_anno['value']['int64Value']) return True return False @@ -282,7 +390,7 @@ def parse_p4rt(self, p4rt_member: Dict[str, Any]) -> None: { "name": "INVALID", "value": "AAA=" } ''' - self.p4rt_value = p4rt_member["value"] + self.p4rt_value = str(p4rt_member["value"]) @sai_parser_from_p4rt @@ -313,7 +421,7 @@ def parse_p4rt(self, p4rt_enum: Dict[str, Any]) -> None: print("Parsing enum: " + self.name) self.name = self.name[:-2] - self.bitwidth = p4rt_enum['underlyingType'][BITWIDTH_TAG] + self.bitwidth = int(p4rt_enum['underlyingType'][BITWIDTH_TAG]) self.members = [SAIEnumMember.from_p4rt(enum_member) for enum_member in p4rt_enum[MEMBERS_TAG]] # Register enum type info. @@ -362,19 +470,19 @@ def _parse_sai_table_attribute_annotation(self, p4rt_anno_list: Dict[str, Any]) if self._parse_sai_common_annotation(kv): continue elif kv['key'] == 'type': - self.type = kv['value']['stringValue'] + self.type = str(kv['value']['stringValue']) elif kv['key'] == 'default_value': # "default" is a reserved keyword and cannot be used. 
- self.default = kv['value']['stringValue'] + self.default = str(kv['value']['stringValue']) elif kv['key'] == 'isresourcetype': - self.isresourcetype = kv['value']['stringValue'] + self.isresourcetype = str(kv['value']['stringValue']) elif kv['key'] == 'isreadonly': - self.isreadonly = kv['value']['stringValue'] + self.isreadonly = str(kv['value']['stringValue']) elif kv['key'] == 'objects': - self.object_name = kv['value']['stringValue'] + self.object_name = str(kv['value']['stringValue']) elif kv['key'] == 'skipattr': - self.skipattr = kv['value']['stringValue'] + self.skipattr = str(kv['value']['stringValue']) elif kv['key'] == 'match_type': - self.match_type = kv['value']['stringValue'] + self.match_type = str(kv['value']['stringValue']) else: raise ValueError("Unknown attr annotation " + kv['key']) @@ -404,7 +512,7 @@ def __init__(self): self.as_attr: bool = False self.param_actions: List[str] = [] - def parse_p4rt(self, p4rt_counter: Dict[str, Any]) -> None: + def parse_p4rt(self, p4rt_counter: Dict[str, Any], var_ref_graph: P4VarRefGraph) -> None: ''' This method parses the P4Runtime counter object and populates the SAI counter object. @@ -422,19 +530,27 @@ def parse_p4rt(self, p4rt_counter: Dict[str, Any]) -> None: "size": "262144" } ''' - # print("Parsing counter: " + self.name) + print("Parsing counter: " + self.name) self.__parse_sai_counter_annotation(p4rt_counter) counter_storage_type = SAITypeSolver.get_object_sai_type(self.bitwidth) self.type = counter_storage_type.name self.field = counter_storage_type.sai_attribute_value_field - counter_unit = p4rt_counter['spec']['unit'] + counter_unit: str = p4rt_counter['spec']['unit'] if counter_unit in ['BYTES', 'PACKETS', 'BOTH']: self.counter_type = counter_unit.lower() else: raise ValueError(f'Unknown counter unit: {counter_unit}') + # If actions are specified by annotation, then we skip finding the referenced actions from the IR. 
+ if len(self.param_actions) == 0 and self.raw_name in var_ref_graph.var_refs: + for ref in var_ref_graph.var_refs[self.raw_name]: + if ref.caller_type == 'P4Action': + self.param_actions.append(ref.caller) + + print(f"Counter {self.name} is referenced by {self.param_actions}") + return def __parse_sai_counter_annotation(self, p4rt_counter: Dict[str, Any]) -> None: @@ -460,7 +576,7 @@ def __parse_sai_counter_annotation(self, p4rt_counter: Dict[str, Any]) -> None: if self._parse_sai_common_annotation(kv): continue elif kv['key'] == 'action_names': - self.param_actions = kv['value']['stringValue'].split(",") + self.param_actions = str(kv['value']['stringValue']).split(",") elif kv['key'] == 'as_attr': self.as_attr = True if kv['value']['stringValue'] == "true" else False else: @@ -496,13 +612,13 @@ def parse_p4rt(self, p4rt_table_key: Dict[str, Any]) -> None: } ''' - self.bitwidth = p4rt_table_key[BITWIDTH_TAG] + self.bitwidth = int(p4rt_table_key[BITWIDTH_TAG]) # print("Parsing table key: " + self.name) if OTHER_MATCH_TYPE_TAG in p4rt_table_key: - self.match_type = p4rt_table_key[OTHER_MATCH_TYPE_TAG].lower() + self.match_type = str(p4rt_table_key[OTHER_MATCH_TYPE_TAG].lower()) elif MATCH_TYPE_TAG in p4rt_table_key: - self.match_type = p4rt_table_key[MATCH_TYPE_TAG].lower() + self.match_type = str(p4rt_table_key[MATCH_TYPE_TAG].lower()) else: raise ValueError(f'No valid match tag found') @@ -583,7 +699,7 @@ def parse_p4rt(self, p4rt_table_action_param: Dict[str, Any]) -> None: { "id": 1, "name": "dst_vnet_id", "bitwidth": 16 } ''' - self.bitwidth = p4rt_table_action_param[BITWIDTH_TAG] + self.bitwidth = int(p4rt_table_action_param[BITWIDTH_TAG]) # print("Parsing table action param: " + self.name) if STRUCTURED_ANNOTATIONS_TAG in p4rt_table_action_param: @@ -691,15 +807,15 @@ def __parse_sai_table_annotations(self, p4rt_table_preamble: Dict[str, Any]) -> if self._parse_sai_common_annotation(kv): continue elif kv['key'] == 'isobject': - self.is_object = 
kv['value']['stringValue'] + self.is_object = str(kv['value']['stringValue']) elif kv['key'] == 'ignored': self.ignored = True elif kv['key'] == 'stage': - self.stage = kv['value']['stringValue'] + self.stage = str(kv['value']['stringValue']) elif kv['key'] == 'api': - self.api_name = kv['value']['stringValue'] + self.api_name = str(kv['value']['stringValue']) elif kv['key'] == 'api_type': - self.api_type = kv['value']['stringValue'] + self.api_type = str(kv['value']['stringValue']) if self.is_object == None: self.is_object = 'false' @@ -846,26 +962,26 @@ def __init__(self): self.sai_apis: List[DASHAPISet] = [] @staticmethod - def from_p4rt_file(p4rt_json_file_path: str, ignore_tables: List[str]) -> 'DASHSAIExtensions': + def from_p4rt_file(p4rt_json_file_path: str, ignore_tables: List[str], var_ref_graph: P4VarRefGraph) -> 'DASHSAIExtensions': print("Parsing SAI APIs BMv2 P4Runtime Json file: " + p4rt_json_file_path) with open(p4rt_json_file_path) as p4rt_json_file: p4rt = json.load(p4rt_json_file) - return DASHSAIExtensions.from_p4rt(p4rt, name = 'dash_sai_apis', ignore_tables = ignore_tables) + return DASHSAIExtensions.from_p4rt(p4rt, name = 'dash_sai_apis', ignore_tables = ignore_tables, var_ref_graph = var_ref_graph) - def parse_p4rt(self, p4rt_value: Dict[str, Any], ignore_tables: List[str]) -> None: + def parse_p4rt(self, p4rt_value: Dict[str, Any], ignore_tables: List[str], var_ref_graph) -> None: self.__parse_sai_enums_from_p4rt(p4rt_value) - self.__parse_sai_counters_from_p4rt(p4rt_value) + self.__parse_sai_counters_from_p4rt(p4rt_value, var_ref_graph) self.__parse_sai_apis_from_p4rt(p4rt_value, ignore_tables) def __parse_sai_enums_from_p4rt(self, p4rt_value: Dict[str, Any]) -> None: all_p4rt_enums = p4rt_value[TYPE_INFO_TAG][SERIALIZABLE_ENUMS_TAG] self.sai_enums = [SAIEnum.from_p4rt(enum_value, name = enum_name) for enum_name, enum_value in all_p4rt_enums.items()] - def __parse_sai_counters_from_p4rt(self, p4rt_value: Dict[str, Any]) -> None: + def 
__parse_sai_counters_from_p4rt(self, p4rt_value: Dict[str, Any], var_ref_graph: P4VarRefGraph) -> None: all_p4rt_counters = p4rt_value[COUNTERS_TAG] for p4rt_counter in all_p4rt_counters: - counter = SAICounter.from_p4rt(p4rt_counter) + counter = SAICounter.from_p4rt(p4rt_counter, var_ref_graph) self.sai_counters.append(counter) def __parse_sai_apis_from_p4rt(self, program: Dict[str, Any], ignore_tables: List[str]) -> None: @@ -1113,6 +1229,7 @@ def __get_uniq_sai_api(self, sai_api: DASHAPISet) -> None: parser = argparse.ArgumentParser(description='P4 SAI API generator') parser.add_argument('filepath', type=str, help='Path to P4 program RUNTIME JSON file') parser.add_argument('apiname', type=str, help='Name of the new SAI API') + parser.add_argument('--ir', type=str, help="Path to P4 program IR JSON file") parser.add_argument('--print-sai-lib', type=bool) parser.add_argument('--ignore-tables', type=str, default='', help='Comma separated list of tables to ignore') args = parser.parse_args() @@ -1124,8 +1241,11 @@ def __get_uniq_sai_api(self, sai_api: DASHAPISet) -> None: os.chdir(os.path.dirname(os.path.realpath(__file__))) + p4ir = P4IRTree.from_file(args.ir) + var_ref_graph = P4VarRefGraph(p4ir) + # Parse SAI data from P4 runtime json file - dash_sai_exts = DASHSAIExtensions.from_p4rt_file(p4rt_file_path, args.ignore_tables.split(',')) + dash_sai_exts = DASHSAIExtensions.from_p4rt_file(p4rt_file_path, args.ignore_tables.split(','), var_ref_graph) dash_sai_exts.post_parsing_process() if args.print_sai_lib: diff --git a/dash-pipeline/dockerfiles/Dockerfile.saithrift-bldr b/dash-pipeline/dockerfiles/Dockerfile.saithrift-bldr index 874252b43..c8233327a 100644 --- a/dash-pipeline/dockerfiles/Dockerfile.saithrift-bldr +++ b/dash-pipeline/dockerfiles/Dockerfile.saithrift-bldr @@ -27,7 +27,7 @@ ENV DASH_SAIGEN_DEPS python3 python3-pip RUN apt-get update -q && \ apt-get install -y --no-install-recommends $SAI_PTF_DEPS $DASH_SAIGEN_DEPS && \ - pip3 install ctypesgen jinja2 
+ pip3 install ctypesgen jinja2 jsonpath-ng ENV SAI_THRIFT_DEPS automake bison flex g++ git libboost-all-dev libevent-dev libssl-dev libtool make pkg-config