Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #357

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Dev #357

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 70 additions & 7 deletions python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def safeunpickle(path):
class _call_no_record_scope:
def __enter__(self): pass
def __exit__(self, *exc): pass

def __call__(self, func):
def inner(*args, **kw):
with self:
Expand Down Expand Up @@ -135,6 +136,8 @@ class no_grad(flag_scope):
''' no_grad scope, all variable created inside this
scope will stop grad.

//lang[zh-cn] 没有梯度的区域,所有在这个区域内创建的变量都不再求梯度

Example::

import jittor as jt
Expand All @@ -147,10 +150,13 @@ def __init__(self, **jt_flags):
self.jt_flags = jt_flags
jt_flags["no_grad"] = 1


class enable_grad(flag_scope):
''' enable_grad scope, all variable created inside this
scope will start grad.

//lang[zh-cn] 赋予梯度的区域,所有在这个区域内创建的变量都会具有梯度

Example::

import jittor as jt
Expand All @@ -163,11 +169,13 @@ def __init__(self, **jt_flags):
self.jt_flags = jt_flags
jt_flags["no_grad"] = 0


single_log_capture = None


class log_capture_scope(_call_no_record_scope):
"""log capture scope

//lang[zh-cn]得到log的区域,用以判断其重要性并赋值
example::

with jt.log_capture_scope(log_v=0) as logs:
Expand Down Expand Up @@ -206,7 +214,7 @@ def __exit__(self, *exc):

class profile_scope(_call_no_record_scope):
""" profile scope

//lang[zh-cn] 写简介报告的区域
example::

with jt.profile_scope() as report:
Expand Down Expand Up @@ -264,7 +272,10 @@ def single_process_scope(rank=0):

All the mpi code inside this scope will have not affect.
mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1,

//lang[zh-cn]单一程序区域
//lang[zh-cn]这个区域的代码仅会被一个程序执行
//lang[zh-cn]这个区域内的mpi代码将不会产生影响
//lang[zh-cn]mpi.world_rank() 和 mpi.local_rank() 将返回 0, world_size() 将返回 1,
example::

@jt.single_process_scope(rank=0)
Expand Down Expand Up @@ -298,6 +309,11 @@ def array(data, dtype=None):
:param dtype: The data type of the Var. If None, the data type will be inferred from the data.
:type dtype: str, jittor type-cast function, or None.

//lang[zh-cn]:通过数字,列表(List),或其它计图变量来建立一个计图的变量
//lang[zh-cn]:数据参数:初始化变量的数值
//lang[zh-cn]:数据类型:具体数值,列表,numpy.ndarray类n维数组变量,或计图变量jittor.Var
//lang[zh-cn]:数据类型参数:变量的数据类型,若没有声明,则由数据推断得出
//lang[zh-cn]:描述数据类型的参数类型:字符串,计图数据类型转换功能,或者无
----------------

Example::
Expand Down Expand Up @@ -342,7 +358,13 @@ def random(shape, dtype="float32", type="uniform"):
:type dtype: str, jittor type-cast function, or None.
:param type: The random distribution, can be 'uniform' or 'normal'.
:type type: str

//lang[zh-cn]建立一个随机的计图变量
//lang[zh-cn]:程序结构:随机变量的结构
//lang[zh-cn]:结构数据类型:列表或元组
//lang[zh-cn]:程序数据类型:随机变量的数据类型
//lang[zh-cn]:描述相关数据类型的类型:字符串,计图数据类型转换功能或者无
//lang[zh-cn]:程序类型:随机分配,可以是一致的(uniform)或者正规的(normal)
//lang[zh-cn]:用以描述类型的类型
----------------

Example::
Expand Down Expand Up @@ -398,6 +420,13 @@ def ones(shape, dtype="float32"):
:type dtype: str, jittor type-cast function, or None.
:return: The output Var.
:rtype: jittor.Var
//lang[zh-cn]建立一个计图变量,且其所有的元素都置1
//lang[zh-cn]:参数结构:输出变量的结构
//lang[zh-cn]:结构数据类型:列表或元组
//lang[zh-cn]:程序的数据结构:输出变量的数据结构
//lang[zh-cn]:描述相关数据类型的类型:字符串,计图数据类型转换功能或者无
//lang[zh-cn]:返回值:输出变量
//lang[zh-cn]:返回数据类型:计图的变量(jittor.Var)
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
Expand All @@ -410,9 +439,16 @@ def ones_like(x):
:type x: jt.Var
:return: The output Var.
:rtype: jittor.Var
//lang[zh-cn]:建立一个所有元素置1的计图变量,且和参数x有相同的结构
//lang[zh-cn]:参数x:参考计图变量
//lang[zh-cn]:参数x类型:jt.Var的类型
//lang[zh-cn]:返回值:返回输出变量
//lang[zh-cn]:返回值类型:计图变量类型,jittor.Var

'''
return ones(x.shape,x.dtype)


def zeros(shape, dtype="float32"):
''' Constructs a jittor Var with all elements set to 0.

Expand All @@ -422,6 +458,12 @@ def zeros(shape, dtype="float32"):
:type dtype: str, jittor type-cast function, or None.
:return: The output Var.
:rtype: jittor.Var
//lang[zh-cn]:建立一个所有元素置0的计图变量
//lang[zh-cn]:结构数据类型:列表或元组
//lang[zh-cn]:程序的数据结构:输出变量的数据结构
//lang[zh-cn]:描述相关数据类型的类型:字符串,计图数据类型转换功能或者无
//lang[zh-cn]:返回值:返回输出变量
//lang[zh-cn]:返回值类型:计图变量类型,jittor.Var
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
Expand All @@ -437,35 +479,56 @@ def full(shape,val,dtype="float32"):
:param dtype: The data type of the output Var. Defaults to jt.float32.
:type dtype: str, jittor type-cast function, or None.
:return: The output Var.
:rtype: jittor.Var
:rtype: jittor.Var
//lang[zh-cn]:建立一个所有元素置为参数val的计图变量
//lang[zh-cn]:形状参数:输出变量的形状参数
//lang[zh-cn]:val参数:输出变量的值
//lang[zh-cn]:val参数类型:数值类型
//lang[zh-cn]:数据类型参数:输出变量的数据类型。默认为jt.float32类型
//lang[zh-cn]:描述相关数据类型的类型:字符串,计图数据类型转换功能或者无
//lang[zh-cn]:返回值:输出变量
//lang[zh-cn]:返回值类型:jittor.Var
'''
if not isinstance(shape, (NanoVector, Sequence)):
shape = (shape,)
return unary(val, dtype).broadcast(shape)

def full_like(x,val):

def full_like(x, val):
''' Constructs a jittor Var with all elements set to val and shape same with x.
:param x: The reference jittor Var.
:type x: jt.Var.
:param val: The value of the output Var.
:type val: number.
:return: The output Var.
:rtype: jittor.Var
//lang[zh-cn]:建立一个元素值都置val的计图变量且和参数x有相同的结构
//lang[zh-cn]:参数x:参考计图变量
//lang[zh-cn]:参数x种类:jt.Var
//lang[zh-cn]:val参数:输出参数值
//lang[zh-cn]:返回值:输出变量
//lang[zh-cn]:返回值类型:jittor.Var
'''
return full(x.shape,val,x.dtype)


def zeros_like(x):
''' Constructs a jittor Var with all elements set to 0 and shape same with x.

:param x: The reference jittor Var.
:type x: jt.Var
:return: The output Var.
:rtype: jittor.Var
//lang[zh-cn]:构建一个所有元素值置0的计图变量且与参数x有相同的结构
//lang[zh-cn]:参数x:jt.Var类型
//lang[zh-cn]:返回:输出变量
//lang[zh-cn]:返回值类型:jittor.Var
'''
return zeros(x.shape,x.dtype)

flags = core.Flags()


def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
""" return the sample variance. If unbiased is True, Bessel's correction will be used.

Expand Down Expand Up @@ -711,7 +774,7 @@ def rand_like(x, dtype=None) -> Var:
[0.58626485 0.35345772 0.5638483 ]], dtype=float32)
'''
if dtype is None: dtype = x.dtype
return jt.random(x.shape, dtype)
return jt.random(x.shape, x.dtype)

def randn_like(x, dtype=None) -> Var:
''' samples random values from standard normal distribution with the same shape as x.
Expand Down
66 changes: 63 additions & 3 deletions python/jittor/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class Dataset(object):
'''
Base class for reading data.

//lang[zh-cn]用于读取数据的基础类
Args::

[in] batch_size(int): batch size, default 16.
Expand All @@ -89,7 +90,16 @@ class Dataset(object):
[in] buffer_size(int): buffer size for each worker in bytes, default(512MB).
[in] keep_numpy_array(bool): return numpy array rather than jittor array, default(False).
[in] endless(bool): will this dataset yield data forever, default(False).


//lang[zh-cn]参数::
//lang[zh-cn] [in] batch_size(int): 一次训练所抓取的数据样本数量,默认 16
//lang[zh-cn] [in] shuffle(bool): 是否打乱每次遍历的数据集,默认 否
//lang[zh-cn] [in] drop_last(bool): 若为真,则最后一组训练的数据样本数量可能小于batch_size,默认 真
//lang[zh-cn] [in] num_workers(int): 用于加载数据的工作线程数
//lang[zh-cn] [in] buffer_size(int): 每个工作线程的缓冲区大小(以字节为单位),默认值 512MB
//lang[zh-cn] [in] keep_numpy_array(bool): 是否返回 numpy 数组而非是否打乱每次遍历的数据集,默认不 jittor 数组,默认 否
//lang[zh-cn] [in] endless(bool): 是否持续在数据集中采样,默认否
Example::

class YourDataset(Dataset):
Expand Down Expand Up @@ -146,7 +156,7 @@ def __len__(self):
def set_attrs(self, **kw):
'''
You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size.

//lang[zh-cn] 您可以使用set_attrs函数设置数据集的属性,包括total_len, batch_size, shuffle, drop_last, num_workers, buffer_size参数。
Example::

dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
Expand All @@ -160,6 +170,15 @@ def set_attrs(self, **kw):
* num_workers: number of workers for loading data
* buffer_size: buffer size for each worker in bytes, default(512MB).
* stop_grad: stop grad for data, default(True).

//lang[zh-cn]参数:
//lang[zh-cn]* batch_size(int): 一次训练所抓取的数据样本数量,默认 16
//lang[zh-cn]* total_len(int): 总体长度
//lang[zh-cn]* shuffle(bool): 是否打乱每次遍历的数据集,默认 否
//lang[zh-cn]* drop_last(bool): 若为真,则最后一组训练的数据样本数量可能小于batch_size,默认 真
//lang[zh-cn]* num_workers: 用于加载数据的工作线程数
//lang[zh-cn]* buffer_size: 每个工作线程的缓冲区大小(以字节为单位),默认值 512MB
//lang[zh-cn]* stop_grad: 是否停止梯度值传播,默认是
'''
for k,v in kw.items():
assert hasattr(self, k), k
Expand All @@ -170,6 +189,7 @@ def set_attrs(self, **kw):
def to_jittor(self, batch):
'''
Change batch data to jittor array, such as np.ndarray, int, and float.
//lang[zh-cn]将处理数据转换为jittor数组,例如np.ndarray, int, and float.
'''
if self.keep_numpy_array: return batch
if isinstance(batch, jt.Var): return batch
Expand Down Expand Up @@ -200,12 +220,17 @@ def collate_batch(self, batch):

[in] batch(list): A list of variables, such as jt.var, Image.Image, np.ndarray, int, float, str and so on.

//lang[zh-cn]将数据集中的数据转换成统一可调度的批处理形式

//lang[zh-cn]参数::
//lang[zh-cn] [in] batch(list): 一个变量的list,例如jt.var, Image.Image, np.ndarray, int, float, str等等。
'''
return collate_batch(batch)

def terminate(self):
'''
Terminate is used to terminate multi-process worker reading data.
//lang[zh-cn]终止用于终止读取数据的多进程工作线程。
'''
if hasattr(self, "workers"):
for w in self.workers:
Expand All @@ -222,6 +247,9 @@ def _worker_main(self, worker_id, buffer, status):
# it is not work on ubuntu 16.04. but worked on ubuntu 20.04
# it seems like the static value of parallel compiler
# is not correctly init.
# //lang[zh-cn] parallel_op_compiler仍有问题,
# //lang[zh-cn] 它在ubuntu 16.04系统上不能工作但在ubuntu 20.04却可以
# //lang[zh-cn] 似乎并行编译器的静态值没有正确初始化
jt.flags.use_parallel_op_compiler = 0
import time
try:
Expand Down Expand Up @@ -250,6 +278,7 @@ def _worker_main(self, worker_id, buffer, status):
start = now

# load and transform data
# //lang[zh-cn] 加载并转换数据
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
Expand All @@ -261,6 +290,7 @@ def _worker_main(self, worker_id, buffer, status):
start = now

# send data to main process
# //lang[zh-cn] 将数据发送到主进程
if mp_log_v:
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
try:
Expand All @@ -285,8 +315,9 @@ def _worker_main(self, worker_id, buffer, status):
exit(0)

def display_worker_status(self):
''' Display dataset worker status, when dataset.num_workers > 0, it will display infomation blow:
''' Display dataset worker status, when dataset.num_workers > 0, it will display information below:

//lang[zh-cn]显示数据集工作线程状态,当dataset.num_workers>0时,将显示如下信息
.. code-block:: console

progress:479/5005
Expand Down Expand Up @@ -320,6 +351,22 @@ def display_worker_status(self):
* load: worker load time
* buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes).

//lang[zh-cn]输出含义:

//lang[zh-cn]* progress: 加载数据集过程 (当前/整体)
//lang[zh-cn]* batch: 批处理时间,不包括数据加载时间
//lang[zh-cn]* wait: 主进程等待工作进程的时间
//lang[zh-cn]* recv: 采集批处理数据的时间
//lang[zh-cn]* to_jittor: 批处理数据为 jittor 变量的时间
//lang[zh-cn]* recv_raw_call: 被调用的基础recv_raw总数
//lang[zh-cn]* last 10 workers: 主要进程所加载的最后10个工作的id
//lang[zh-cn]* 表格含义

//lang[zh-cn]*ID: 工作的编号
//lang[zh-cn]* wait: 工作的等待时间
//lang[zh-cn]* open: 工作图像打开时间
//lang[zh-cn]* load: 工作加载时间
//lang[zh-cn]* buffer: 环缓冲区状态,例如有多少可用空间,左索引,右索引,总大小(字节)。
Example::

from jittor.dataset import Dataset
Expand Down Expand Up @@ -596,7 +643,6 @@ class ImageFolder(Dataset):
* root/label2/img1.png
* root/label2/img2.png
* ...

Args::

[in] root(string): Root directory path.
Expand All @@ -607,6 +653,19 @@ class ImageFolder(Dataset):
* class_to_idx(dict): map from class_name to class_index.
* imgs(list): List of (image_path, class_index) tuples

//lang[zh-cn] 图像分类数据集,从目录加载图像和标签::
//lang[zh-cn] * root/label1/img1.png
//lang[zh-cn] * root/label1/img2.png
//lang[zh-cn] * ...
//lang[zh-cn] * root/label2/img1.png
//lang[zh-cn] * root/label2/img2.png
//lang[zh-cn] * ...
//lang[zh-cn]参数::
//lang[zh-cn] [in] root(string):根目录路径。
//lang[zh-cn]属性::
//lang[zh-cn]* classes(list): 类名称的列表。
//lang[zh-cn]* class_to_idx(dict): 从class_name到class_index的地图。
//lang[zh-cn]* imgs(list): (image_path、class_index)元组列表
Example::

train_dir = './data/celebA_train'
Expand Down Expand Up @@ -644,6 +703,7 @@ def __getitem__(self, k):
class VarDataset(Dataset):
""" Dataset using Var directly, TensorDataset is alias of VarDataset, Example::

//lang[zh-cn] 数据集直接使用 Var,TensorDataset 是 VarDataset 的别名
import jittor as jt
from jittor.dataset import VarDataset

Expand Down