Coverage for geometric_kernels/lab_extras/tensorflow/extras.py: 91%

109 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-16 21:43 +0000

1import sys 

2 

3import lab as B 

4import tensorflow as tf 

5import tensorflow_probability as tfp 

6from beartype.typing import Any, List, Optional 

7from lab import dispatch 

8from plum import Union 

9 

# Numeric inputs accepted by several overloads in this module:
# plain Python numbers, TensorFlow tensors and NumPy arrays.
_Numeric = Union[
    B.Number,
    B.TFNumeric,
    B.NPNumeric,
]

11 

12 

@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Gathers elements of `a` along `axis` at `index` locations.

    Thin wrapper over ``tf.experimental.numpy.take_along_axis``.

    :param a: Source array.
    :param index: Indices to gather along `axis`.
    :param axis: Axis along which to gather.
    :return: The gathered values.
    """
    if sys.version_info[:2] <= (3, 9):
        # On Python 3.9 and below the call below errors without an explicit
        # integer cast of the indices, so they are cast to int32 here.
        # NOTE(review): confirm int32 (not int64) is the intended dtype.
        index = tf.cast(index, tf.int32)
    return tf.experimental.numpy.take_along_axis(a, index, axis=axis)

23 

24 

@dispatch
def from_numpy(_: B.TFNumeric, b: Union[List, B.Numeric, B.NPNumeric, B.TFNumeric]):  # type: ignore
    """
    Convert `b` into a TensorFlow tensor.

    The first argument only selects this TensorFlow implementation via
    dispatch; its value is ignored.

    :param _: Any TensorFlow tensor (dispatch reference only).
    :param b: A list, number, NumPy array or TensorFlow tensor.
    :return: `b` as a TensorFlow tensor.
    """
    converted = tf.convert_to_tensor(b)
    return converted

31 

32 

@dispatch
def trapz(y: _Numeric, x: _Numeric, dx=None, axis=-1):  # type: ignore
    """
    Integrate `y` along `axis` with the composite trapezoidal rule.

    Delegates to ``tfp.math.trapz``.

    :param y: Values of the integrand.
    :param x: Sample points corresponding to `y`.
    :param dx: Sample spacing, forwarded to ``tfp.math.trapz``.
    :param axis: Axis to integrate along (defaults to the last one).
    :return: The approximated integral.
    """
    return tfp.math.trapz(y, x=x, dx=dx, axis=axis)

39 

40 

@dispatch
def norm(x: _Numeric, ord: Optional[Any] = None, axis: Optional[int] = None):  # type: ignore
    """
    Matrix or vector norm.

    Follows the NumPy convention that ``ord=None`` selects the default norm.
    ``tf.norm`` does not list ``None`` among its supported ``ord`` values;
    its own default is ``"euclidean"``. ``None`` is therefore translated to
    ``"euclidean"`` before delegating.

    :param x: Input tensor.
    :param ord: Order of the norm, forwarded to ``tf.norm``. ``None`` means
        ``"euclidean"``.
    :param axis: Axis along which to compute the norm. ``None`` treats the
        whole tensor as a single vector (``tf.norm`` semantics).
    :return: The norm(s) as a tensor.
    """
    if ord is None:
        # Map the NumPy-style default onto tf.norm's documented default.
        ord = "euclidean"
    return tf.norm(x, ord=ord, axis=axis)

47 

48 

@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Return `num` numbers spaced evenly on a log scale.

    The exponents are `num` evenly spaced values from `start` to `stop`
    (produced by ``tf.linspace``). The result is ``base ** exponents``.

    NOTE(review): the default `base` here is 50.0, whereas
    ``numpy.logspace`` defaults to 10.0. Confirm this is intentional and
    consistent with the other backends.

    :param start: First exponent.
    :param stop: Last exponent.
    :param num: Number of samples.
    :param base: Base of the log space.
    :return: Tensor of `num` values.
    """
    exponents = tf.linspace(start, stop, num)
    return tf.math.pow(base, exponents)

56 

57 

@dispatch
def degree(a: B.TFNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose main diagonal holds the column
    sums of `a`, i.e. for each node the number (or total weight) of the
    connections it has.

    :param a: Adjacency matrix.
    :return: Diagonal degree matrix.
    """
    column_sums = tf.reduce_sum(a, axis=0)  # type: ignore
    return tf.linalg.diag(column_sums)

68 

69 

@dispatch
def eigenpairs(L: B.TFNumeric, k: int):
    """
    Return the eigenpairs of the `k` lowest eigenvalues of the symmetric
    positive semi-definite matrix `L`.

    ``tf.linalg.eigh`` returns eigenvalues in non-decreasing order, so the
    first `k` eigenvalues and the first `k` eigenvector columns are the
    lowest ones.

    :param L: Symmetric positive semi-definite matrix (assumed n x n).
    :param k: Number of eigenpairs to keep.
    :return: Tuple ``(eigenvalues, eigenvectors)``. For an n x n `L` these
        have shapes ``[k]`` and ``[n, k]``.
    """
    eigenvalues, eigenvectors = tf.linalg.eigh(L)
    lowest_values = eigenvalues[:k]
    lowest_vectors = eigenvectors[:, :k]
    return lowest_values, lowest_vectors

78 

79 

@dispatch
def set_value(a: B.TFNumeric, index: int, value: float):
    """
    Return a copy of `a` with ``a[index]`` replaced by `value`.

    No in-place update happens: a new tensor is built with ``tf.where``,
    selecting `value` at the matching position and `a` elsewhere.

    NOTE(review): the position mask is ``tf.range(len(a))``, which
    broadcasts against the last axis of `a`. This appears intended for
    1-D `a`. A negative `index` matches no position, so `a` comes back
    unchanged.

    :param a: Tensor to update.
    :param index: Position to overwrite.
    :param value: Replacement value.
    :return: Updated tensor.
    """
    positions = tf.range(len(a))
    is_target = positions == index
    return tf.where(is_target, value, a)

88 

89 

@dispatch
def dtype_double(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's double-precision floating point dtype.

    `reference` only selects the TensorFlow backend via dispatch; its
    value is not inspected.

    :param reference: A TensorFlow random state.
    :return: ``tf.float64``.
    """
    double_dtype = tf.float64
    return double_dtype

96 

97 

@dispatch
def float_like(reference: B.TFNumeric):
    """
    Return a floating point dtype suited to `reference`.

    :param reference: A TensorFlow tensor.
    :return: ``reference.dtype`` when it is already floating point,
        otherwise ``tf.float64``.
    """
    dtype = reference.dtype
    return dtype if dtype.is_floating else tf.float64

109 

110 

@dispatch
def dtype_integer(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's default integer dtype for this module.

    `reference` only selects the TensorFlow backend via dispatch; its
    value is not inspected.

    :param reference: A TensorFlow random state.
    :return: ``tf.int32``.
    """
    integer_dtype = tf.int32
    return integer_dtype

117 

118 

@dispatch
def int_like(reference: B.TFNumeric):
    """
    Return an integer dtype suited to `reference`.

    :param reference: A TensorFlow tensor.
    :return: ``reference.dtype`` when it is already an integer type,
        otherwise ``tf.int32``.
    """
    dtype = reference.dtype
    if not dtype.is_integer:
        return tf.int32
    return dtype

126 

127 

@dispatch
def get_random_state(key: B.TFRandomState):
    """
    Capture the state of a TensorFlow random generator.

    :param key: The random generator of type `B.TFRandomState`.
    :return: Tuple ``(state, algorithm)``: a ``tf.identity`` copy of
        ``key.state`` and ``key.algorithm``.
    """
    state_copy = tf.identity(key.state)
    algorithm = key.algorithm
    return state_copy, algorithm

137 

138 

@dispatch
def restore_random_state(key: B.TFRandomState, state):
    """
    Build a new TensorFlow random generator from a saved state.

    `key` only selects the TensorFlow backend via dispatch. It is not
    modified; a fresh ``tf.random.Generator`` is returned instead.

    :param key: The random generator (dispatch reference only).
    :param state: Tuple ``(state_tensor, algorithm)``, the same layout
        that `get_random_state` returns.
    :return: A new ``tf.random.Generator`` created from `state`.
    """
    state_tensor, algorithm = state[0], state[1]
    return tf.random.Generator.from_state(
        state=tf.identity(state_tensor), alg=algorithm
    )

152 

153 

@dispatch
def create_complex(real: _Numeric, imag: B.TFNumeric):
    """
    Combine real and imaginary parts into a complex TensorFlow tensor.

    `real` is first converted to a TensorFlow tensor and cast to the dtype
    of `imag`, so that both parts passed to ``tf.complex`` share a dtype.

    :param real: Real part (number, NumPy array or TensorFlow tensor).
    :param imag: Imaginary part (TensorFlow tensor).
    :return: Complex tensor built by ``tf.complex(real, imag)``.
    """
    target_dtype = B.dtype(imag)
    real_part = B.cast(target_dtype, from_numpy(imag, real))
    return tf.complex(real_part, imag)

166 

167 

@dispatch
def complex_like(reference: B.TFNumeric):
    """
    Return a complex dtype compatible with `reference`.

    Obtained by promoting ``tf.complex64`` together with ``reference.dtype``
    via ``B.promote_dtypes``.

    :param reference: A TensorFlow tensor.
    :return: The promoted complex dtype.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(tf.complex64, reference_dtype)

174 

175 

@dispatch
def is_complex(reference: B.TFNumeric):
    """
    Check whether `reference` has a complex dtype.

    :param reference: A TensorFlow tensor.
    :return: True if its dtype is ``tf.complex64`` or ``tf.complex128``.
    """
    dtype = B.dtype(reference)
    return dtype in (tf.complex64, tf.complex128)

182 

183 

@dispatch
def cumsum(x: B.TFNumeric, axis=None):
    """
    Return cumulative sum (optionally along axis).

    Follows the NumPy convention: when `axis` is None the input is flattened
    first, and the cumulative sum is taken over all elements in row-major
    order, giving a 1-D tensor.

    :param x: Input tensor.
    :param axis: Axis along which to accumulate. None means "flatten first".
    :return: Tensor of cumulative sums.
    """
    if axis is None:
        # tf.math.cumsum needs an integer axis (its own default is 0) and
        # cannot convert None, so reproduce NumPy's flatten-then-accumulate.
        return tf.math.cumsum(tf.reshape(x, [-1]), axis=0)
    return tf.math.cumsum(x, axis=axis)

190 

191 

@dispatch
def qr(x: B.TFNumeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x`.

    :param x: Matrix to factorize.
    :param mode: ``"complete"`` requests full-size factors
        (``full_matrices=True``). Any other value, including the default
        ``"reduced"``, gives the reduced factorization.
    :return: Tuple ``(Q, R)``.
    """
    q, r = tf.linalg.qr(x, full_matrices=(mode == "complete"))
    return q, r

200 

201 

@dispatch
def slogdet(x: B.TFNumeric):
    """
    Compute the sign and log-absolute-determinant of the matrix `x`.

    :param x: Square matrix.
    :return: Plain tuple ``(sign, logdet)`` unpacked from
        ``tf.linalg.slogdet``.
    """
    sign_of_det, log_abs_det = tf.linalg.slogdet(x)
    return sign_of_det, log_abs_det

209 

210 

@dispatch
def eigvalsh(x: B.TFNumeric):
    """
    Eigenvalues of a Hermitian (or real symmetric) matrix `x`.

    :param x: Hermitian or real symmetric matrix.
    :return: Eigenvalues as computed by ``tf.linalg.eigvalsh``.
    """
    eigenvalues = tf.linalg.eigvalsh(x)
    return eigenvalues

217 

218 

@dispatch
def reciprocal_no_nan(x: B.TFNumeric):
    """
    Element-wise reciprocal ``1 / x`` that yields 0 wherever ``x == 0``.

    :param x: Input tensor.
    :return: Tensor of reciprocals, with zeros in place of divisions by zero.
    """
    reciprocals = tf.math.reciprocal_no_nan(x)
    return reciprocals

225 

226 

@dispatch
def complex_conj(x: B.TFNumeric):
    """
    Element-wise complex conjugate of `x`.

    :param x: Input tensor.
    :return: ``tf.math.conj(x)``.
    """
    conjugate = tf.math.conj(x)
    return conjugate

233 

234 

@dispatch
def logical_xor(x1: B.TFNumeric, x2: B.TFNumeric):
    """
    Element-wise logical XOR of two tensors.

    :param x1: First operand.
    :param x2: Second operand.
    :return: Result of ``tf.math.logical_xor(x1, x2)``.
    """
    exclusive_or = tf.math.logical_xor(x1, x2)
    return exclusive_or

241 

242 

@dispatch
def count_nonzero(x: B.TFNumeric, axis=None):
    """
    Count the non-zero elements of `x`.

    :param x: Input tensor.
    :param axis: Axis to count along, forwarded to
        ``tf.math.count_nonzero``. None counts over the whole tensor.
    :return: Number of non-zero entries (per `axis` if given).
    """
    nonzero_count = tf.math.count_nonzero(x, axis=axis)
    return nonzero_count

249 

250 

@dispatch
def dtype_bool(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's boolean dtype.

    `reference` only selects the TensorFlow backend via dispatch; its
    value is not inspected.

    :param reference: A TensorFlow random state.
    :return: ``tf.bool``.
    """
    boolean_dtype = tf.bool
    return boolean_dtype

257 

258 

@dispatch
def bool_like(reference: B.TFNumeric):
    """
    Return the type of the reference if it is of boolean type.
    Otherwise return `bool` dtype of a backend based on the reference.

    :param reference: A TensorFlow tensor whose dtype is inspected.
    :return: ``reference.dtype`` if it is boolean, otherwise ``tf.bool``.
    """
    # Dispatch on TensorFlow tensors. This is the TF backend, and
    # `is_bool` is an attribute of `tf.DType`, not of NumPy dtypes.
    reference_dtype = reference.dtype
    if reference_dtype.is_bool:
        return reference_dtype
    return tf.bool