Skip to content

Commit

Permalink
Adding ability to access and iterate over function attributes. (#259)
Browse files Browse the repository at this point in the history
* Adding ability to access and iterate over function attributes.
  • Loading branch information
idavis authored Oct 20, 2023
1 parent 2af2995 commit 26885c9
Show file tree
Hide file tree
Showing 8 changed files with 481 additions and 29 deletions.
2 changes: 1 addition & 1 deletion eng/psakefile.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ task check-environment {
}

Assert ((Test-InVirtualEnvironment) -eq $true) ($env_message -Join ' ')
exec { & $Python -m pip install pip~=23.1 }
exec { & $Python -m pip install pip~=23.3 }
}

task init -depends check-environment {
Expand Down
16 changes: 13 additions & 3 deletions eng/utils.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,21 @@ function Test-AllowedToDownloadLlvm {
}

function Test-InCondaEnvironment {
(Test-Path env:\CONDA_PREFIX)
$found = (Test-Path env:\CONDA_PREFIX)
if ($found) {
$condaPrefix = $env:CONDA_PREFIX
Write-BuildLog "Found conda environment: $condaPrefix"
}
$found
}

function Test-InVenvEnvironment {
(Test-Path env:\VIRTUAL_ENV)
$found = (Test-Path env:\VIRTUAL_ENV)
if ($found) {
$venv = $env:VIRTUAL_ENV
Write-BuildLog "Found venv environment: $venv"
}
$found
}

function Test-InVirtualEnvironment {
Expand Down Expand Up @@ -301,5 +311,5 @@ function install-llvm {
if ($clear_cache_var) {
Remove-Item -Path Env:QIRLIB_CACHE_DIR
}
}
}
}
3 changes: 3 additions & 0 deletions pyqir/pyqir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from pyqir._simple import SimpleModule
from pyqir._entry_point import entry_point
from pyqir._basicqis import BasicQisBuilder
from pyqir._constants import ATTR_FUNCTION_INDEX, ATTR_RETURN_INDEX

__all__ = [
"ArrayType",
Expand Down Expand Up @@ -117,4 +118,6 @@
"result_id",
"result_type",
"result",
"ATTR_FUNCTION_INDEX",
"ATTR_RETURN_INDEX",
]
5 changes: 5 additions & 0 deletions pyqir/pyqir/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

ATTR_RETURN_INDEX = 0
ATTR_FUNCTION_INDEX = 4294967295 # -1 u32
30 changes: 27 additions & 3 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
# Licensed under the MIT License.

from enum import Enum
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import (
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
)

class ArrayType(Type):
"""An array type."""
Expand All @@ -19,6 +28,10 @@ class ArrayType(Type):
class Attribute:
"""An attribute."""

@property
def string_kind(self) -> str:
"""The kind of this attribute as a string."""
...
@property
def string_value(self) -> Optional[str]:
"""The value of this attribute as a string, or `None` if this is not a string attribute."""
Expand All @@ -44,7 +57,13 @@ class AttributeList:
"""The attributes for the function itself."""
...

class AttributeSet:
class AttributeIterator(Iterator[Attribute]):
"""An iterator of attributes for a specific part of a function."""

def __iter__(self) -> Iterator[Attribute]: ...
def __next__(self) -> Attribute: ...

class AttributeSet(Iterable[Attribute]):
"""A set of attributes for a specific part of a function."""

def __contains__(self, item: str) -> bool:
Expand All @@ -63,6 +82,7 @@ class AttributeSet:
:returns: The attribute.
"""
...
def __iter__(self) -> Iterator[Attribute]: ...

class BasicBlock(Value):
"""A basic block."""
Expand Down Expand Up @@ -1223,13 +1243,17 @@ def if_result(
...

def add_string_attribute(
function: Function, kind: str, value: Optional[str] = None
function: Function,
kind: str,
value: Optional[str] = None,
index: Optional[int] = None,
) -> bool:
"""
Adds a string attribute to the given function.
:param function: The function.
:param key: The attribute key.
:param value: The attribute value.
:param index: The optional attribute index, defaults to the function index.
"""
...
64 changes: 51 additions & 13 deletions pyqir/src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use pyo3::{
types::{PyBytes, PyLong, PyString},
PyRef,
};
use qirlib::values;
use qirlib::values::{self, get_string_attribute_kind, get_string_attribute_value};
use std::{
borrow::Borrow,
collections::hash_map::DefaultHasher,
Expand All @@ -33,6 +33,7 @@ use std::{
ops::Deref,
ptr::NonNull,
slice, str,
vec::IntoIter,
};

/// A value.
Expand Down Expand Up @@ -522,21 +523,20 @@ pub(crate) struct Attribute(LLVMAttributeRef);

#[pymethods]
impl Attribute {
/// The id of this attribute as a string.
///
/// :type: str
#[getter]
fn string_kind(&self) -> String {
unsafe { get_string_attribute_kind(self.0) }
}

/// The value of this attribute as a string, or `None` if this is not a string attribute.
///
/// :type: typing.Optional[str]
#[getter]
fn string_value(&self) -> Option<&str> {
unsafe {
if LLVMIsStringAttribute(self.0) == 0 {
None
} else {
let mut len = 0;
let value = LLVMGetStringAttributeValue(self.0, &mut len).cast();
let value = slice::from_raw_parts(value, len.try_into().unwrap());
Some(str::from_utf8(value).unwrap())
}
}
fn string_value(&self) -> Option<String> {
unsafe { get_string_attribute_value(self.0) }
}
}

Expand Down Expand Up @@ -588,6 +588,24 @@ pub(crate) struct AttributeSet {
index: LLVMAttributeIndex,
}

/// An iterator of attributes for a specific part of a function.
#[pyclass]
struct AttributeIterator {
iter: IntoIter<Py<Attribute>>,
}

#[pymethods]
impl AttributeIterator {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
// Returning `None` from `__next__` indicates that that there are no further items.
// and maps to StopIteration
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Py<Attribute>> {
slf.iter.next()
}
}

#[pymethods]
impl AttributeSet {
/// Tests if an attribute is a member of the set.
Expand Down Expand Up @@ -622,6 +640,23 @@ impl AttributeSet {
Ok(Attribute(attr))
}
}

fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<AttributeIterator>> {
let function = slf.function.borrow(slf.py()).into_super().into_super();

unsafe {
let attrs = qirlib::values::get_attributes(function.as_ptr(), slf.index);
let items = attrs
.into_iter()
.map(|a| Py::new(slf.py(), Attribute(a)).expect("msg"));
Py::new(
slf.py(),
AttributeIterator {
iter: items.collect::<Vec<Py<Attribute>>>().into_iter(),
},
)
}
}
}

#[derive(FromPyObject)]
Expand Down Expand Up @@ -863,12 +898,14 @@ pub(crate) fn extract_byte_string<'py>(py: Python<'py>, value: &Value) -> Option
// :param function: The function.
// :param kind: The attribute kind.
// :param value: The attribute value.
// :param index: The optional attribute index, defaults to the function index.
#[pyfunction]
#[pyo3(text_signature = "(function, key, value)")]
#[pyo3(text_signature = "(function, key, value, index)")]
pub(crate) fn add_string_attribute<'py>(
function: PyRef<Function>,
key: &'py PyString,
value: Option<&'py PyString>,
index: Option<u32>,
) {
let function = function.into_super().into_super().as_ptr();
let key = key.to_string_lossy();
Expand All @@ -881,6 +918,7 @@ pub(crate) fn add_string_attribute<'py>(
Some(ref x) => x.as_bytes(),
None => &[],
},
index.unwrap_or(LLVMAttributeFunctionIndex),
);
}
}
93 changes: 89 additions & 4 deletions pyqir/tests/test_string_attributes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import List
import pyqir
from pyqir import (
Attribute,
ATTR_FUNCTION_INDEX,
ATTR_RETURN_INDEX,
Builder,
IntType,
ModuleFlagBehavior,
Module,
Context,
add_string_attribute,
Expand Down Expand Up @@ -41,9 +45,9 @@ def test_round_trip_serialize_parse() -> None:
function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod)
add_string_attribute(function, "foo", "bar")
# also test for non-value attributes
add_string_attribute(function, "entry_point", "")
add_string_attribute(function, "entry_point")
# test behavior of empty attribute
add_string_attribute(function, "", "")
add_string_attribute(function, "")
ir = str(mod)
parsed_mod = Module.from_ir(Context(), ir, "test")
assert str(parsed_mod) == str(mod)
Expand All @@ -54,7 +58,7 @@ def test_duplicate_attr_key_replaces_previous() -> None:
void = pyqir.Type.void(mod.context)
function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod)
add_string_attribute(function, "foo", "bar")
add_string_attribute(function, "foo", "")
add_string_attribute(function, "foo")
ir = str(mod)
# Tests that subsequently added attributes with the same key
# replace previously added ones
Expand All @@ -77,3 +81,84 @@ def test_attribute_alphabetical_sorting() -> None:
# Tests that attributes are sorted alphabetically by key,
# irrespective of their value
assert 'attributes #0 = { "1" "A"="123" "a"="a" "b"="A" "c" }' in ir


def test_function_attributes_can_be_iterated_in_alphabetical_order() -> None:
mod = pyqir.Module(pyqir.Context(), "test")
void = pyqir.Type.void(mod.context)
function = Function(FunctionType(void, []), Linkage.EXTERNAL, "test_function", mod)
# add them out of order, they will be sorted automatically
add_string_attribute(function, "required_num_results", "1")
add_string_attribute(function, "entry_point")
add_string_attribute(function, "required_num_qubits", "2")
attrs: List[Attribute] = list(function.attributes.func)
assert len(attrs) == 3
# Tests that attributes are sorted alphabetically by indexing into the list
assert attrs[0].string_kind == "entry_point"
assert attrs[0].string_value == ""
assert attrs[1].string_kind == "required_num_qubits"
assert attrs[1].string_value == "2"
assert attrs[2].string_kind == "required_num_results"
assert attrs[2].string_value == "1"


def test_parameter_attrs() -> None:
mod = pyqir.Module(pyqir.Context(), "test")
void = pyqir.Type.void(mod.context)
i8 = IntType(mod.context, 8)
function = Function(
FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod
)
# add them out of order, they will be sorted automatically
add_string_attribute(function, "zeroext", "", 1)
add_string_attribute(function, "mycustom", "myvalue", 1)

# params have their own AttributeSet
attrs = list(function.attributes.param(0))

attr = attrs[0]
assert attr.string_kind == "mycustom"
assert attr.string_value == "myvalue"

attr = attrs[1]
assert attr.string_kind == "zeroext"
assert attr.string_value == ""


def test_return_attrs_can_be_added_and_read() -> None:
mod = pyqir.Module(pyqir.Context(), "test")
void = pyqir.Type.void(mod.context)
i8 = IntType(mod.context, 8)
function = Function(
FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod
)
builder = Builder(mod.context)
builder.ret(None)

add_string_attribute(function, "mycustom", "myvalue", ATTR_RETURN_INDEX)

# params have their own AttributeSet
attrs = list(function.attributes.ret)

attr = attrs[0]
assert attr.string_kind == "mycustom"
assert attr.string_value == "myvalue"


def test_explicit_function_index_attrs_can_be_added_and_read() -> None:
mod = pyqir.Module(pyqir.Context(), "test")
void = pyqir.Type.void(mod.context)
i8 = IntType(mod.context, 8)
function = Function(
FunctionType(void, [i8]), Linkage.EXTERNAL, "test_function", mod
)
builder = Builder(mod.context)
builder.ret(None)

add_string_attribute(function, "mycustom", "myvalue", ATTR_FUNCTION_INDEX)

attrs = list(function.attributes.func)

attr = attrs[0]
assert attr.string_kind == "mycustom"
assert attr.string_value == "myvalue"
Loading

0 comments on commit 26885c9

Please sign in to comment.