Skip to content

Commit

Permalink
[StubGen] Intruduce length:return (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebaszm authored Jan 11, 2024
1 parent 847446f commit c1dc18c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 59 deletions.
16 changes: 16 additions & 0 deletions ProxyStubGenerator/Log.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ def Info(self, text, file=""):
self.infos.append("%s: %s%s: %s%s%s" % (self.name, self.cinfo, self.creset, file, ": " if file else "", text))
self.__Print(self.infos[-1])

def InfoLine(self, obj, text, file=""):
if self.show_infos:
if not file: file = self.file
try:
if not file: file = os.path.basename(obj.parser_file)
line = str(obj.parser_line)
except:
try:
file = os.path.basename(obj.parent.parser_file)
line = obj.parent.parser_line
except:
file = ""
line = ""
self.infos.append("%s: %s%s: %s%s" % (self.name, self.cinfo, self.creset, ("%s(%s): " % (file, line)) if file else "", text))
self.__Print(self.infos[-1])

def DocIssue(self, text, file=""):
if self.show_doc_issues:
if not file: file = self.file
Expand Down
172 changes: 113 additions & 59 deletions ProxyStubGenerator/StubGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def GenerateLuaData(emit, interfaces_list, enums_list, source_file=None, tree=No
retval = []
items = [ name ]

def ParseLength(param, length, vars):
def ParseLength(param, length, retval, vars):
def _Convert(size):
if size == "char":
return "8"
Expand All @@ -280,6 +280,8 @@ def _Convert(size):
if len(length) == 1:
if length[0] == "void":
return [_Convert(param.Type().size), None]
elif length[0] == "return":
return [_Convert(retval.type.Type().size), "return"]

for v in vars:
if v.name == length[0]:
Expand Down Expand Up @@ -312,7 +314,7 @@ def Convert(paramtype, retval, vars, hresult=False):
length_param = None

if param.IsPointer():
parsed = ParseLength(param, meta.length if meta.length else meta.maxlength, vars)
parsed = ParseLength(param, meta.length if meta.length else meta.maxlength, retval, vars)

if parsed[1]:
length_param = parsed[1]
Expand Down Expand Up @@ -455,7 +457,15 @@ def Convert(paramtype, retval, vars, hresult=False):

text.extend(rv)

retval.append(" { %s }" % ", ".join(text))
skip = False

for v in m.vars:
if (v.meta.length and (v.meta.length[0] == "return")):
skip = True
break

if not skip:
retval.append(" { %s }" % ", ".join(text))

for p in m.vars:
param = Convert(p, m.retval, m.vars)
Expand All @@ -469,16 +479,22 @@ def Convert(paramtype, retval, vars, hresult=False):
text.extend(param)

skip = False
skip2 = False

for v in m.vars:
if (v.meta.length and (v.meta.length[0] == p.name)) and not v.meta.maxlength:
# Will not be on the wire!
if (v.meta.length and (v.meta.length[0] == p.name)):
# Will not be on the wire on inbound!
skip = True

if not skip:
if (not v.meta.output or v.meta.input):
# Will not be on the wire on inbound and outbound!
skip2 = True

if not skip2:
if p.meta.input or not p.meta.output:
if not skip:
params.append(" { %s }" % ", ".join(text))
params.append(" { %s }" % ", ".join(text))

if not skip:
if p.meta.output:
retval.append(" { %s }" % ", ".join(text))

Expand Down Expand Up @@ -623,8 +639,19 @@ def _FindLength(length_name, variable_name):
return EmitIdentifier(-2, interface, \
CppParser.Temporary(self.identifier.parent, ["uint8_t", variable_name], ["sizeof(%s)" % self.kind], []), variable_name)

elif length_name[0] == "return":
result = "result"

if isinstance(self.identifier.parent.retval.type.Type(), CppParser.Integer):
if self.identifier.parent.retval.type.Type().signed:
result = "(result > 0? result : 0)"

return EmitIdentifier(-3, interface, \
CppParser.Temporary(self.identifier.parent, [self.identifier.parent.retval.type, "result"], [result], []), "result")

elif length_name[0] == "sizeof" and len(length_name) == 4:
matches = [v for v in self.identifier.parent.vars if v.name == length_name[2]]

if matches:
return EmitIdentifier(-2, interface, \
CppParser.Temporary(self.identifier.parent, ["uint8_t", variable_name], ["sizeof(_%s)" % matches[0].name], []), variable_name)
Expand Down Expand Up @@ -671,6 +698,7 @@ def _FindLength(length_name, variable_name):
no_length_warnings = is_variable_length_of
is_return_value = (override_name == "result")

self.is_on_wire = True
self.identifier = identifier
self.identifier_type = type_
self.identifier_kind = self.identifier_type.Type()
Expand Down Expand Up @@ -742,20 +770,20 @@ def _FindLength(length_name, variable_name):
# If have a input length, assume it's a max-length parameter even if not stated explicitly
if (self.is_buffer and not self.max_length and self.length.is_input):
self.max_length = self.length
if self.is_output_only and not self.length.is_output:
# It appears it should've been @maxlength instead of @length... Fix this
self.length = None
log.WarnLine(self.identifier, "'%s': this buffer is output-only, hence expected @maxlength here, not @length" % self.trace_proto)
elif self.is_input and self.is_output:

if self.is_input and self.is_output:
log.WarnLine(self.identifier, \
"'%s': maximum length of this input/output buffer is assumed to be same as @length, use @maxlength to disambiguate" % \
(self.trace_proto))
else:
log.InfoLine(self.identifier, "'%s': @maxlength not specified, assuming same as @length" % self.trace_proto)

if (self.is_buffer and self.is_output and not self.max_length):
raise TypenameError(self.identifier, "'%s': can't deduce maximum length of the buffer, use @maxlength" % self.trace_proto)

#if (self.is_buffer and self.max_length and not self.length):
# log.WarnLine(self.identifier, "'%s': length of returned buffer is not specified" % self.trace_proto)
if (self.is_buffer and self.max_length and not self.length):
log.WarnLine(self.identifier, "'%s': length of returned buffer is not specified; using @maxlength, but this may be inefficient" % self.trace_proto)
self.length = self.max_length

# Is it a hresult?
self.is_hresult = self.identifier_type.TypeName().endswith("Core::hresult") \
Expand Down Expand Up @@ -978,10 +1006,10 @@ def read_rpc_type(self):
# Raw buffers
elif self.is_buffer:
assert self.length or self.max_length, "Invalid type for buffer"
if self.length:
return "Buffer<%s>(%s, %s)" % (self.length.type_name, self.length.as_rvalue, self.as_lvalue)
elif self.max_length:
if self.max_length:
return "Buffer<%s>(%s, %s)" % (self.max_length.type_name, self.max_length.as_rvalue, self.as_lvalue)
elif self.length:
return "Buffer<%s>(%s, %s)" % (self.length.type_name, self.length.as_rvalue, self.as_lvalue)
else:
Unreachable()

Expand Down Expand Up @@ -1017,7 +1045,7 @@ def write_rpc_type(self):
# Raw buffers
elif self.is_buffer:
assert self.max_length, "Invalid type for buffer " + self.name
return "Buffer<%s>(%s, %s)" % (self.max_length.type_name, self.max_length.as_rvalue, self.as_rvalue)
return "Buffer<%s>(%s, %s)" % (self.length.type_name, self.length.as_rvalue, self.as_rvalue)

# Strings
elif self.is_string:
Expand Down Expand Up @@ -1110,6 +1138,8 @@ def PrepareParams(method, interface):

for index, var in enumerate(method.vars):
params.append(EmitParam(interface, var, index = index))
if var.meta.length and var.meta.length[0] == "return":
retval.is_on_wire = False

if retval:
output_params.append(retval)
Expand Down Expand Up @@ -1213,6 +1243,8 @@ def EmitStubMethodImplementation(index, method, interface_name, interface, retva
has_hresult = retval and retval.is_hresult

def ReadParameter(p):
assert(p)
assert(p.is_on_wire)
if p.is_compound:
kind = p.kind.Merge()
if not p.suppress_type:
Expand Down Expand Up @@ -1242,6 +1274,7 @@ def ReadParameter(p):
CheckRange(p, p)

def TemporaryParameter(p):
assert(p)
if p.is_buffer:
output_buffers.append(p)

Expand All @@ -1252,6 +1285,7 @@ def TemporaryParameter(p):
return 0

def AllocateBuffer(p):
assert(p)
assert p.is_buffer
assert p.max_length

Expand All @@ -1274,8 +1308,11 @@ def AllocateBuffer(p):

if has_buffer_reuse:
temp_buffer = AuxIdentifier(CppParser.Void(), CppParser.Ref.POINTER, vars["tempbuffer"])
emit.Line("if (%s > %s) {" % (p.max_length.as_rvalue, p.length.as_rvalue))
emit.IndentInc()

if (p.max_length.as_rvalue != p.length.as_rvalue):
emit.Line("if (%s > %s) {" % (p.max_length.as_rvalue, p.length.as_rvalue))
emit.IndentInc()

emit.Line("%s{};" % temp_buffer.as_temporary_no_cv)
emit.Line()

Expand Down Expand Up @@ -1324,8 +1361,11 @@ def AllocateBuffer(p):
emit.Line("}")
emit.Line()
emit.Line("%s = static_cast<%s>(%s);" % (p.as_rvalue, p.proto, temp_buffer.as_rvalue))
emit.IndentDec()
emit.Line("}")

if (p.max_length.as_rvalue != p.length.as_rvalue):
emit.IndentDec()
emit.Line("}")

emit.Line()

emit.Line()
Expand All @@ -1338,20 +1378,24 @@ def AllocateBuffer(p):
emit.Line("}")

def ReleaseBuffer(p, large_buffer):
assert(p)
assert p.is_buffer
if (p.max_length and (p.length.type.Type().size == "long")):
emit.Line("RPC::Administrator::Instance().Free(%s);" % large_buffer.as_rvalue)

def WriteParameter(p):
if p.is_compound:
kind = p.kind.Merge()
params = [EmitParam(interface, v, (p.name + "." + v.name)) for v in kind.vars]
for pp in params:
WriteParameter(pp)
else:
emit.Line("%s.%s;" % (vars["writer"], p.write_rpc_type))
assert(p)
if p.is_on_wire:
if p.is_compound:
kind = p.kind.Merge()
params = [EmitParam(interface, v, (p.name + "." + v.name)) for v in kind.vars]
for pp in params:
WriteParameter(pp)
else:
emit.Line("%s.%s;" % (vars["writer"], p.write_rpc_type))

def AcquireInterface(p):
assert(p)
assert p.proxy
assert p.proxy_instance
emit.Line("%s* %s = nullptr;" %(p.type_name, p.proxy_instance))
Expand All @@ -1371,6 +1415,7 @@ def AcquireInterface(p):
emit.Line("}")

def ReleaseProxy(p):
assert(p)
assert p.proxy
emit.Line("if (%s != nullptr) {" % p.proxy)
emit.IndentInc()
Expand All @@ -1379,6 +1424,7 @@ def ReleaseProxy(p):
emit.Line("}")

def RegisterInterface(p):
assert(p)
assert p.interface_id
if not isinstance(p.interface_id, AuxIdentifier):
# Interface ID comes from a parameter
Expand Down Expand Up @@ -1499,7 +1545,7 @@ def CallImplementation(retval, params):
emit.Line("}")

if has_restricted_parameters and (not retval or not retval.is_hresult):
log.WarnLine(method, "'%s': method is using restricted parameters but its return value type is not 'Core::hresult'" % method.name)
log.InfoLine(method, "'%s': method is using restricted parameters, but its return value type is not 'Core::hresult'" % method.name)

def EmitStubMethod(index, last, method, interface_name, interface, prepared_params):
retval, params, input_params, output_params, proxy_params, return_proxy_params = prepared_params
Expand Down Expand Up @@ -1614,6 +1660,8 @@ def EmitProxyMethodImplementation(index, method, interface_name, interface, retv
emit.Line()

def WriteParameter(p):
assert(p)
assert(p.is_on_wire)
if p.is_compound:
kind = p.kind.Merge()
params = [EmitParam(interface, v, (p.name + "." + v.name), True) for v in kind.vars]
Expand All @@ -1624,35 +1672,37 @@ def WriteParameter(p):
emit.Line("writer.%s;" % p.write_rpc_type)

def ReadParameter(p):
if p.is_compound:
kind = p.kind.Merge()
params = [EmitParam(interface, v, (p.name + "." + v.name), True) for v in kind.vars]
for pp in params:
ReadParameter(pp)
assert(p)
if p.is_on_wire:
if p.is_compound:
kind = p.kind.Merge()
params = [EmitParam(interface, v, (p.name + "." + v.name), True) for v in kind.vars]
for pp in params:
ReadParameter(pp)

elif p.return_proxy:
emit.Line("%s = reinterpret_cast<%s>(Interface(%s.%s, %s));" % \
(p.as_lvalue, p.proto_no_cv, vars["reader"], p.read_rpc_type, p.interface_id.as_rvalue))

elif p.is_buffer:
CheckFrame(p)
CheckSize(p)

if p.length and p.length.is_output:
emit.Line("%s = %s.%s;" % (p.length.as_lvalue, vars["reader"], p.read_rpc_type))
else:
# No one's interested in the return length, perhaps it's sent via method's return value
emit.Line("%s.%s;" % (vars["reader"], p.read_rpc_type))

elif p.return_proxy:
emit.Line("%s = reinterpret_cast<%s>(Interface(%s.%s, %s));" % \
(p.as_lvalue, p.proto_no_cv, vars["reader"], p.read_rpc_type, p.interface_id.as_rvalue))
elif p.is_string:
CheckFrame(p)
CheckSize(p)
emit.Line("%s = %s.%s;" % (p.as_lvalue, vars["reader"], p.read_rpc_type))

elif p.is_buffer:
CheckFrame(p)
CheckSize(p)

if p.length:
emit.Line("%s = %s.%s;" % (p.length.as_lvalue, vars["reader"], p.read_rpc_type))
else:
# No one's interested in the return length, perhaps it's sent via method's return value
emit.Line("%s.%s;" % (vars["reader"], p.read_rpc_type))

elif p.is_string:
CheckFrame(p)
CheckSize(p)
emit.Line("%s = %s.%s;" % (p.as_lvalue, vars["reader"], p.read_rpc_type))

else:
CheckFrame(p)
emit.Line("%s = %s.%s;" % (p.as_lvalue, vars["reader"], p.read_rpc_type))
CheckRange(p, p)
CheckFrame(p)
emit.Line("%s = %s.%s;" % (p.as_lvalue, vars["reader"], p.read_rpc_type))
CheckRange(p, p)

if input_params:
emit.Line("RPC::Data::Frame::Writer %s(%s->Parameters().Writer());" % (vars["writer"], vars["message"]))
Expand Down Expand Up @@ -2070,8 +2120,12 @@ def GenerateIdentification(name):
print(" @interface:{expr} - specifies a parameter or value indicating interface ID value for void* interface passing")
print(" @length:{expr} - specifies a buffer length value (a constant, a parameter name or a math expression)")
print(" @maxlength:{expr} - specifies a maximum buffer length value (a constant, a parameter name or a math expression),")
print(" if not specified, @length is considered as maximum length; use round parenthesis for expressions",)
print(" e.g.: @length:bufferSize @length:(width*height*4), @length:(sizeof(uint64_t)), @length:void")
print(" if @maxlength is not specified, expresion from @length is used")
print("")
print(" Examples:")
print(" @length:bufferSize, @length:32, @length:(width*height*4), @length:(sizeof(uint64_t))")
print(" @length:return - length is carried in the return value")
print(" @length:void - length is the size of one element")
print("")
print("JSON-RPC-related parameters:")
print(" @json - takes a C++ class in for JSON-RPC generation")
Expand Down

0 comments on commit c1dc18c

Please sign in to comment.