-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlut_util.py
369 lines (318 loc) · 13.4 KB
/
lut_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
#!/usr/bin/env python3
#
# Copyright (C) 2018 Troy Sankey
# This file is released under the GNU GPL, version 3 or a later revision.
# For further details see the COPYING file
#
# References:
# - Hald CLUT reference: http://www.quelsolaar.com/technology/clut.html
# - 3D LUT (3DL) reference: http://download.autodesk.com/us/systemdocs/pdf/lustre_color_management_user_guide.pdf#page=14
import math
import numbers
import sys
from array import array
from decimal import Decimal
from pathlib import Path
import png
from PIL import Image
def is_perfect_six_root(n):
    """
    Return True when `n` is the sixth power of a non-negative integer.

    The floating-point sixth root is only an estimate, so both the truncated
    estimate and its successor are verified with exact integer arithmetic.
    """
    estimate = int(n ** (1 / 6.))
    return any(candidate ** 6 == n for candidate in (estimate, estimate + 1))
def uniform_intervals(end, samples, floating_point=False):
    """
    Make `samples` uniformly distributed numbers from 0 to `end`.

    With floating_point=False the values are rounded to ints. If rounding
    makes any gap between neighbours deviate from the ideal spacing by more
    than 7%, a ValueError is raised.
    """
    spacing = end / float(samples - 1)
    points = [i * spacing for i in range(samples)]
    if floating_point:
        return points

    rounded = [int(round(p)) for p in points]
    for previous, current in zip(rounded, rounded[1:]):
        relative_error = abs(float(current - previous) / spacing - 1.0)
        if relative_error > 0.07:
            raise ValueError('input parameters to uniform_intervals would yield a non-uniform distribution.')
    return rounded
class Value3D(object):
    """
    A triple of components (e.g. an RGB color or an RGB input coordinate).

    Supports component-wise addition of two Value3D objects, multiplication
    by a scalar on either side, and iteration/unpacking over the components.
    """

    def __init__(self, components):
        self.components = tuple(components)

    def __str__(self):
        return 'Value3D({},{},{})'.format(*self.components)

    def __bytes__(self):
        # bytes() requires __bytes__ to return a bytes object; returning the
        # str representation made bytes(value) raise TypeError.
        return self.__str__().encode('utf-8')

    def __repr__(self):
        return self.__str__()

    def __iter__(self):
        return iter(self.components)

    def __add__(self, y):
        """Component-wise sum with another Value3D."""
        return Value3D(
            (
                self.components[0] + y.components[0],
                self.components[1] + y.components[1],
                self.components[2] + y.components[2],
            )
        )

    def __mul__(self, y):
        """Multiply every component by the scalar `y`."""
        return Value3D(
            (
                self.components[0] * y,
                self.components[1] * y,
                self.components[2] * y,
            )
        )

    def __rmul__(self, x):
        return self.__mul__(x)
def index_3d(data, size, r_idx, g_idx, b_idx):
    """
    Return the 3-element slice of a flattened 3-channel cubic matrix.

    `data` stores a `size` x `size` x `size` cube of 3-channel entries, with
    the first index (r_idx) varying fastest and the last (b_idx) slowest.
    """
    start = 3 * (r_idx + g_idx * size + b_idx * size ** 2)
    return data[start:start + 3]
class ColorLUT(object):
    """
    A 3D color lookup table stored as a flat ``array.array`` of RGB triples.

    The table has ``sample_count`` lattice points per input axis spanning
    ``[0, input_domain]``, so ``data`` holds ``3 * sample_count ** 3`` values.
    ``red_increments_fastest`` describes the layout of ``data``: True when the
    red input index varies fastest (Hald CLUT order), False when blue does.
    """

    # array.array typecodes accepted for integer-valued and real-valued data.
    _INTEGRAL_TYPECODES = 'bBhHiIlLqQ'
    _REAL_TYPECODES = 'fd'

    def __init__(self, data, sample_count=None, input_domain=None, red_increments_fastest=True):
        """
        :param data: flat ``array.array`` of output color components.
        :param sample_count: number of lattice points per axis (at least 2).
        :param input_domain: maximum input coordinate; must be an integer for
            integer-typed ``data`` and a real number for float ``data``.
        :param red_increments_fastest: layout of ``data`` (see class doc).
        :raises ValueError: if any parameter is missing or inconsistent.
        """
        # len() is checked before data[0] so an empty array gets a ValueError
        # instead of an IndexError.
        if not isinstance(data, array) or len(data) == 0 or not isinstance(data[0], numbers.Number):
            raise ValueError('data parameter should be a flat list of numbers.')
        if not isinstance(sample_count, int):
            raise ValueError('sample_count parameter should be of type int.')
        if sample_count < 2:
            # sample_distance below divides by sample_count - 1.
            raise ValueError('sample_count parameter should be at least 2.')
        if not isinstance(input_domain, numbers.Number):
            raise ValueError('input_domain parameter should be a number.')
        if isinstance(input_domain, numbers.Integral):
            expected_typecodes = self._INTEGRAL_TYPECODES
        else:
            expected_typecodes = self._REAL_TYPECODES
        if data.typecode not in expected_typecodes:
            raise ValueError('input_domain parameter should have the same type as the data.')
        if len(data) != 3 * (sample_count ** 3):
            raise ValueError('The sample intervals do not appear to match the matrix dimensions.')
        self.data = data
        self.sample_count = sample_count
        self.input_domain = input_domain
        self.red_increments_fastest = red_increments_fastest
        if data.typecode in self._REAL_TYPECODES:
            self.datatype = numbers.Real
        else:
            self.datatype = numbers.Integral
        # Distance between neighbouring lattice points along any input axis.
        self.sample_distance = self.input_domain / float(self.sample_count - 1)

    def get_color_value_from_index(self, r_idx, g_idx, b_idx):
        """
        Return the stored output color at lattice indices (r_idx, g_idx, b_idx)
        as a Value3D, independent of how ``data`` is laid out.
        """
        if not self.red_increments_fastest:
            # index_3d expects its first index to be the fastest-varying one.
            r_idx, b_idx = b_idx, r_idx
        return Value3D(index_3d(self.data, self.sample_count, r_idx, g_idx, b_idx))

    def _locate(self, value):
        """
        Split one input coordinate into (lower lattice index, fraction).

        The returned index ``idx`` always satisfies
        ``0 <= idx <= sample_count - 2``, so ``idx + 1`` is a valid lattice
        index, and ``0 <= fraction <= 1``. Coordinates outside
        ``[0, input_domain]`` (including ones pushed just past the edge by
        float rounding) are clamped to the nearest edge.
        """
        position = value / self.sample_distance
        position = min(max(position, 0.0), self.sample_count - 1)
        idx = min(math.trunc(position), self.sample_count - 2)
        # Derive the fraction from the same float as the index so the two can
        # never disagree about which cell the coordinate falls in.
        return idx, position - idx

    def get_interpolated_color_value(self, r_input, g_input, b_input):
        """
        Determine the output color value using trilinear interpolation.

        Algorithm adapted from https://en.wikipedia.org/wiki/Trilinear_interpolation

        On Wikipedia, v_d = (v - v_0) / (v_1 - v_0) for each axis. Because
        lattice points are self.sample_distance apart, v_d is simply the
        fractional part of v / self.sample_distance; see _locate(), which also
        keeps v_0's index at or below sample_count - 2 so v_1 stays inside the
        table even at v == input_domain.
        """
        r_0_idx, r_d = self._locate(r_input)
        g_0_idx, g_d = self._locate(g_input)
        b_0_idx, b_d = self._locate(b_input)
        r_1_idx = r_0_idx + 1
        g_1_idx = g_0_idx + 1
        b_1_idx = b_0_idx + 1

        c_000 = self.get_color_value_from_index(r_0_idx, g_0_idx, b_0_idx)
        c_001 = self.get_color_value_from_index(r_0_idx, g_0_idx, b_1_idx)
        c_010 = self.get_color_value_from_index(r_0_idx, g_1_idx, b_0_idx)
        c_011 = self.get_color_value_from_index(r_0_idx, g_1_idx, b_1_idx)
        c_100 = self.get_color_value_from_index(r_1_idx, g_0_idx, b_0_idx)
        c_101 = self.get_color_value_from_index(r_1_idx, g_0_idx, b_1_idx)
        c_110 = self.get_color_value_from_index(r_1_idx, g_1_idx, b_0_idx)
        c_111 = self.get_color_value_from_index(r_1_idx, g_1_idx, b_1_idx)

        # Interpolate along red, then green, then blue.
        c_00 = c_000 * (1.0 - r_d) + c_100 * r_d
        c_01 = c_001 * (1.0 - r_d) + c_101 * r_d
        c_10 = c_010 * (1.0 - r_d) + c_110 * r_d
        c_11 = c_011 * (1.0 - r_d) + c_111 * r_d
        c_0 = c_00 * (1.0 - g_d) + c_10 * g_d
        c_1 = c_01 * (1.0 - g_d) + c_11 * g_d
        return c_0 * (1.0 - b_d) + c_1 * b_d

    def get_values_translated(
        self,
        increment_red_fastest=True,
        output_sample_count=None,
        output_domain=None,
    ):
        """
        Make a lazy iterable of output color values (Value3D) in sequence.

        The output lattice is walked with the red input incrementing fastest
        and blue slowest by default; pass increment_red_fastest=False for the
        opposite order.

        :param output_sample_count: lattice points per axis in the output.
            Defaults to this table's sample_count. When it differs, values are
            computed with trilinear interpolation.
        :param output_domain: scale for the output components. Defaults to this
            table's input_domain. When it differs, each output component is
            multiplied by output_domain / input_domain.
        """
        # The previous None defaults crashed (range(None), None / float).
        if output_sample_count is None:
            output_sample_count = self.sample_count
        if output_domain is None:
            output_domain = self.input_domain
        interpolate_output = output_sample_count != self.sample_count
        scale_output = output_domain != self.input_domain
        scaling_factor = output_domain / float(self.input_domain)

        if increment_red_fastest:
            indexes = (
                (r, g, b)
                for b in range(output_sample_count)
                for g in range(output_sample_count)
                for r in range(output_sample_count)
            )
        else:
            indexes = (
                (r, g, b)
                for r in range(output_sample_count)
                for g in range(output_sample_count)
                for b in range(output_sample_count)
            )

        if interpolate_output:
            # Input coordinate spacing of the requested output lattice.
            input_step = self.input_domain / float(output_sample_count - 1)
            input_values = (Value3D(idx) * input_step for idx in indexes)
            output_values = (
                self.get_interpolated_color_value(*input_value)
                for input_value in input_values
            )
        else:
            output_values = (
                self.get_color_value_from_index(*idx)
                for idx in indexes
            )

        if scale_output:
            output_values = (
                output_value * scaling_factor
                for output_value in output_values
            )
        return output_values

    @classmethod
    def from_haldclut(cls, src):
        """
        Load a Hald CLUT from an 8- or 16-bit RGB PNG file.

        :param src: path of the PNG file.
        :raises ValueError: if the PNG uses a palette, gamma, transparency,
            alpha, greyscale or another bit depth, or its dimensions are not
            those of a Hald CLUT.
        """
        # Open the file ourselves so it is closed deterministically;
        # read_flat() decodes the whole image into an array before returning.
        with open(src, 'rb') as src_file:
            width, height, data, meta = png.Reader(file=src_file).read_flat()
        if 'palette' in meta:
            raise ValueError('The given PNG file uses a color palette. Refusing.')
        if 'gamma' in meta:
            raise ValueError('The given PNG file contains a gamma value. Refusing.')
        if 'transparent' in meta:
            raise ValueError('The given PNG file specifies a transparent color. Refusing.')
        if meta['alpha']:
            raise ValueError('The given PNG file contains an alpha channel. Refusing.')
        if meta['greyscale']:
            raise ValueError('The given PNG file is greyscale. Refusing.')
        if meta['bitdepth'] not in (8, 16):
            raise ValueError('The given PNG file specifies an unsupported bit depth. Refusing.')
        # A level-L Hald CLUT is an L**3 x L**3 image holding (L**2)**3 colors,
        # so width**2 must be a perfect sixth power.
        width_is_square_root_of_perfect_six_root = is_perfect_six_root(width ** 2)
        if width != height or not width_is_square_root_of_perfect_six_root:
            raise ValueError('The given PNG file does not have appropriate Hald CLUT dimensions. Refusing.')
        sample_count = int(round((width ** 2) ** (1. / 3)))
        input_domain = 2 ** meta['bitdepth'] - 1
        return cls(data, sample_count=sample_count, input_domain=input_domain)

    @classmethod
    def from_3dl(cls, src):
        raise NotImplementedError()

    def write_haldclut(self):
        raise NotImplementedError()

    def write_3dl(self, dest):
        """
        Write the LUT to `dest` in 3DL format.

        The first line lists the input sample positions on a 0..1023 scale.
        Each following line is one "R G B" output color scaled to 0..1023 and
        rounded to an integer, with the blue input varying fastest.
        """
        output_domain = 1023
        sample_intervals = uniform_intervals(output_domain, self.sample_count)
        color_value_gen = self.get_values_translated(
            increment_red_fastest=False,
            output_sample_count=self.sample_count,
            output_domain=output_domain,
        )
        with open(dest, 'w') as destfile:
            destfile.write(' '.join(str(v) for v in sample_intervals))
            destfile.write('\n')
            for color in color_value_gen:
                line = ' '.join('{:.0f}'.format(v) for v in color)
                destfile.write(line)
                destfile.write('\n')

    def write_cube(self, dest):
        """
        Write the LUT to `dest` in .cube format.

        A LUT_3D_SIZE header is followed by one output color per line, scaled
        to 0..1 and printed with 7 significant digits, red input fastest.
        """
        output_domain = 1.0
        output_sample_count = self.sample_count
        color_value_gen = self.get_values_translated(
            output_sample_count=output_sample_count,
            output_domain=output_domain,
        )
        with open(dest, 'w') as destfile:
            destfile.write('LUT_3D_SIZE {}'.format(output_sample_count))
            destfile.write('\n')
            for color in color_value_gen:
                line = ' '.join('{:.7g}'.format(v) for v in color)
                destfile.write(line)
                destfile.write('\n')
def alternative_hald_to_3dl(image_path: Path, dest=None):
    """
    Convert a Hald CLUT image to a 3DL file by subsampling its lattice.

    For an image of width ``steps ** 3`` the Hald cube has ``steps ** 2``
    samples per axis. Only ``steps`` evenly spaced samples per axis are kept,
    so the result is approximate. 8-bit input components are rescaled to the
    0..1023 range used by the 3DL header.

    :param image_path: path of the Hald CLUT image.
    :param dest: output path. Defaults to
        ``<image stem>_converted_alternative.3dl``, relative to the current
        working directory.

    Exits the process with status 2 if the image does not have Hald CLUT
    dimensions.
    """
    print("Warning: This script does not return accurate results.", file=sys.stderr)
    with Image.open(image_path) as src:
        # convert() yields a new image holding the pixel data, so it remains
        # usable after the file is closed. It also turns palette, greyscale
        # or alpha images into plain (r, g, b) triples.
        in_ = src.convert('RGB')
    w, h = in_.size
    if w != h:
        print('HALD input is not square.', file=sys.stderr)
        sys.exit(2)
    steps = int(round(math.pow(w, 1 / 3)))
    if steps ** 3 != w:
        print('HALD input size is invalid: %d is not a cube.' % w, file=sys.stderr)
        # Previously execution continued and produced a meaningless LUT.
        sys.exit(2)
    print('%d steps' % steps, file=sys.stderr)

    if dest is None:
        dest = f'{image_path.stem}_converted_alternative.3dl'

    # The Hald cube has n = steps**2 samples per axis, stored red-fastest:
    # flat index = r + n*g + n*n*b. Taking `steps` evenly spaced samples out
    # of 0..n-1 means a stride of (n-1)/(steps-1) = steps+1 along each axis.
    r_stride = steps + 1
    g_stride = steps ** 2 * (steps + 1)
    b_stride = steps ** 4 * (steps + 1)
    data = list(in_.getdata())

    def lookup(ri, gi, bi):
        return data[ri * r_stride + gi * g_stride + bi * b_stride]

    def to_10bit(component):
        # Assume 8-bit input. Map 0..255 onto 0..1023 so full scale matches
        # the header's 1023 (the old `* 4` topped out at 1020).
        return round(component * 1023 / 255)

    header = [1023 * i // (steps - 1) for i in range(steps)]
    with open(dest, 'w') as out:
        out.write(' '.join(str(x) for x in header))
        out.write('\n')
        # Red input slowest, blue input fastest.
        for ri in range(steps):
            for gi in range(steps):
                for bi in range(steps):
                    r, g, b = lookup(ri, gi, bi)
                    out.write('%d %d %d\n' % (to_10bit(r), to_10bit(g), to_10bit(b)))
if __name__ == '__main__':
    # An optional first command-line argument overrides the default input.
    if len(sys.argv) > 1:
        image_path = Path(sys.argv[1])
    else:
        image_path = Path('Neutral_25_converted.png')
    # 0 selects the ColorLUT-based converter; any other value selects the
    # subsampling alternative_hald_to_3dl() path.
    algorithm = 1
    if algorithm == 0:
        dest_type = '3dl'
        # Derive the source format from the file name instead of hard-coding.
        src_ext = image_path.suffix.lstrip('.').lower()
        if src_ext == 'png':
            clut = ColorLUT.from_haldclut(image_path)
        elif src_ext == '3dl':
            clut = ColorLUT.from_3dl(image_path)
        elif src_ext == 'cube':
            # There is no .cube reader. The previous code wrongly parsed
            # .cube files as Hald CLUT PNGs.
            raise NotImplementedError('Reading .cube files is not supported.')
        else:
            raise ValueError('Not an appropriate Color LUT file type: {}'.format(src_ext))
        if dest_type == 'haldclut':
            clut.write_haldclut()
        elif dest_type == '3dl':
            clut.write_3dl(f'{image_path.stem}_converted.3dl')
        elif dest_type == 'cube':
            clut.write_cube(f'{image_path.stem}_converted.cube')
        else:
            raise ValueError('Not an appropriate Color LUT file type: {}'.format(dest_type))
    else:
        alternative_hald_to_3dl(image_path)