Source code for frontend._parse_
import re
import os
from typing import Any
[docs]
class CppRustDocStringParser:
    """Collect logging documentation from comments in ``.cu``/``.rs`` sources."""

    @staticmethod
    def get_logging_docstrings(root: str) -> dict[str, dict[str, str]]:
        """Walk *root* and collect the documentation of every declared log.

        A log is documented by ``// Label: value`` comment lines placed before
        the line that declares it: ``SimpleLog logging("...")``,
        ``/*== push "..."``, ``logging.push("...")`` or ``logging.mark("...")``.
        A ``// Description:`` line switches to description mode: the rest of
        that line and every following comment line are joined with single
        spaces into the ``"Description"`` entry. A blank line discards the
        pending block; a declaration line always consumes it, but only blocks
        carrying a ``Name`` label are stored. The labels ``File``, ``Author``,
        ``License`` and ``https`` are ignored, as are lines containing
        ``#include``.

        Args:
            root (str): Directory searched recursively. ``args.rs`` is skipped
                and only files ending in ``.cu`` or ``.rs`` are read (UTF-8).

        Returns:
            dict[str, dict[str, str]]: Entries sorted by key. The key is the
            quoted log name (or the block's ``Map`` label, which is removed
            from the entry) with ``_`` replaced by ``-``. Each entry holds the
            collected labels, the optional ``"Description"`` and
            ``"filename"``: ``"<parent>.<name>.out"`` when a
            ``SimpleLog logging`` parent was seen earlier in the same file,
            otherwise ``"<name>.out"``.
        """
        # Header labels that never belong to a log's documentation.
        skip_labels = ("File", "Author", "License", "https")
        result: dict[str, dict[str, str]] = {}
        doc: dict[str, str] = {}
        desc = ""
        description_mode = False
        par_name = ""

        def clear() -> None:
            """Discard the pending comment block (``par_name`` is kept)."""
            nonlocal doc, desc, description_mode
            doc = {}
            desc = ""
            description_mode = False

        def append_desc(text: str) -> None:
            """Append *text* to the pending description, space-separated."""
            nonlocal desc
            if text:  # skip empty "//" lines instead of adding double spaces
                desc = f"{desc} {text}" if desc else text

        def register(name: str) -> None:
            """Consume the pending block as the documentation of log *name*."""
            if "Name" in doc:
                if desc:
                    doc["Description"] = desc
                if par_name:
                    doc["filename"] = f"{par_name}.{name}.out"
                else:
                    doc["filename"] = f"{name}.out"
                # A ``Map`` label overrides the key (not the filename).
                key = doc.pop("Map", name)
                result[key.replace("_", "-")] = doc.copy()
            # The declaration consumes its block even without a Name, so stray
            # labels cannot leak into the next log's entry.
            clear()

        def extract_name(line: str) -> str:
            """Return the first double-quoted text of *line*, spaces -> ``_``."""
            # NOTE(review): assumes the line contains a quoted name; without
            # one this yields the line minus its last character.
            start = line.find('"') + 1
            end = line.find('"', start)
            return line[start:end].replace(" ", "_")

        def parse_line(line: str) -> None:
            """Feed one stripped source line to the comment state machine."""
            nonlocal par_name, description_mode
            if not line.strip():
                # A blank line ends (and discards) the pending comment block.
                clear()
                return
            if line.startswith("//"):
                # lstrip instead of strip("//"): the latter also removes
                # trailing slashes of values such as paths.
                content = line.lstrip("/").strip()
                if description_mode:
                    append_desc(content)
                elif content.startswith("Description:"):
                    description_mode = True
                    append_desc(content[len("Description:"):].strip())
                else:
                    label, sep, value = content.partition(":")
                    label = label.strip()
                    if sep and label not in skip_labels:
                        # partition keeps any further ':' inside the value.
                        doc[label] = value.strip()
            elif line.startswith("SimpleLog logging"):
                # Registered without a parent, then becomes the parent of the
                # following push/mark declarations in this file.
                par_name = ""
                name = extract_name(line)
                register(name)
                par_name = name
            elif (
                line.startswith("/*== push")
                or "logging.push(" in line
                or "logging.mark(" in line
            ):
                register(extract_name(line))

        for dirpath, dirnames, filenames in os.walk(root):
            # Sorted traversal keeps the output deterministic when two files
            # document the same key (the last one parsed wins).
            dirnames.sort()
            for filename in sorted(filenames):
                if filename == "args.rs" or not filename.endswith((".cu", ".rs")):
                    continue
                # Fresh state per file: an unterminated comment block must not
                # leak into the next file.
                par_name = ""
                clear()
                path = os.path.join(dirpath, filename)
                with open(path, "r", encoding="utf-8") as source:
                    for raw_line in source:
                        line = raw_line.strip()
                        if "#include" not in line:
                            parse_line(line)
        return dict(sorted(result.items()))
class ParamParser:
    """Read the fields of the Rust ``Args`` struct and their defaults."""

    @staticmethod
    def get_default_params(path: str) -> dict[str, dict[str, Any]]:
        """Get the default parameters.

        Only the body of ``pub struct Args {`` is inspected; parsing stops at
        the first following line that starts with ``}``. Every
        ``pub name: Type,`` field produces one entry.

        Args:
            path (str): The path to the args.rs file (read as UTF-8).

        Returns:
            dict[str, dict[str, Any]]: Entries sorted by field name, with
            ``_`` replaced by ``-`` in the name. Each entry has:

            - ``"value"``: the value of any key containing ``default_value``
              in a preceding ``#[clap(...)]`` or ``#[arg(...)]`` attribute,
              converted to ``int``/``float`` when numeric (integral floats
              become ``int``), otherwise the unquoted string; ``None`` if
              absent.
            - ``"type"``: the third whitespace-separated token of the field
              line with a trailing ``,`` removed.
            - ``"doc"``: ``// Label: value`` comments preceding the field,
              ``"Description"`` (rest of the ``// Description:`` line plus
              every following comment line, space-joined) and ``"list"``
              (``False`` if a comment contains ``Do not list``, else ``True``).
        """
        att_pattern = re.compile(r"#\[(.*?)\]")
        field_pattern = re.compile(r"pub\s+(\w+):\s*([^,]+),?")
        struct_start_pattern = re.compile(r"^pub\s+struct\s+Args\s*\{")
        struct_end_pattern = re.compile(r"^\s*\}")
        # Greedy so nested parentheses such as `value_parser = f(1)` do not
        # cut the argument list short.
        clap_pattern = re.compile(r"(?:clap|arg)\((.*)\)")
        # Comma-separated attribute arguments, ignoring commas inside quotes.
        arg_pattern = re.compile(r'(?:[^,"]|"(?:\\.|[^"\\])*")+')

        curr_attributes: list[str] = []
        inside_struct = False
        result: dict[str, dict[str, Any]] = {}
        doc: dict[str, Any] = {}
        var_type = None
        description_mode = False
        description = ""

        def clear_doc() -> None:
            """Reset the comment/type state collected for the next field."""
            nonlocal doc, description_mode, description, var_type
            doc = {"list": True}
            description_mode = False
            description = ""
            var_type = None

        def append_description(text: str) -> None:
            """Append *text* to the pending description, space-separated."""
            nonlocal description
            if text:  # skip empty "//" lines instead of adding double spaces
                description = f"{description} {text}" if description else text

        def parse_line(line: str) -> None:
            """Update the type and doc state from one line inside the struct."""
            nonlocal description_mode, var_type
            stripped = line.strip()
            if stripped.startswith("pub"):
                parts = stripped.split()
                if len(parts) > 2:
                    var_type = parts[2].rstrip(",")
            if not stripped.startswith("//"):
                return
            # lstrip instead of strip("// "): keep trailing slashes of values.
            text = stripped.lstrip("/").strip()
            if "Do not list" in text:
                doc["list"] = False
            if text.startswith("Description:"):
                description_mode = True
                append_description(text[len("Description:"):].strip())
            elif description_mode:
                append_description(text)
            else:
                # partition keeps any further ':' inside the value; comment
                # lines without a label are simply ignored.
                label, sep, value = text.partition(":")
                if sep:
                    doc[label.strip()] = value.strip()

        # The first field gets the same defaults (list=True) as all others.
        clear_doc()

        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.rstrip()
                stripped = line.strip()
                if not inside_struct:
                    # Skip everything up to the `pub struct Args {` line.
                    if struct_start_pattern.match(stripped):
                        inside_struct = True
                    continue
                parse_line(line)
                if struct_end_pattern.match(stripped):
                    break
                if not stripped or stripped.startswith("//"):
                    continue
                attr_match = att_pattern.match(stripped)
                if attr_match:
                    curr_attributes.append(attr_match.group(1).strip())
                    continue
                field_match = field_pattern.match(stripped)
                if not field_match:
                    # Any other line detaches the collected attributes.
                    curr_attributes = []
                    continue
                field_name = field_match.group(1).replace("_", "-")
                default_value = None
                for attr in curr_attributes:
                    clap_match = clap_pattern.match(attr)
                    if not clap_match:
                        continue
                    for arg in arg_pattern.findall(clap_match.group(1)):
                        key, sep, value = arg.partition("=")
                        if sep and "default_value" in key:
                            default_value = value.strip().strip('"').strip("'")
                if default_value is not None:
                    try:
                        number = float(default_value)
                    except ValueError:
                        pass  # not numeric (bool, path, enum, ...): keep str
                    else:
                        default_value = (
                            int(number) if number.is_integer() else number
                        )
                doc["Description"] = description
                result[field_name] = {
                    "value": default_value,
                    "type": var_type,
                    "doc": doc,
                }
                clear_doc()
                curr_attributes = []
        return dict(sorted(result.items()))