Coverage for geometric_kernels / lab_extras / tensorflow / extras.py: 91%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 15:49 +0000

1from typing import TypeAlias 

2 

3import lab as B 

4import tensorflow as tf 

5import tensorflow_probability as tfp 

6from beartype.typing import Any 

7from lab import dispatch 

8 

# Union of argument types accepted by the dispatch overloads in this file that
# are not restricted to TensorFlow tensors: lab's `B.Number` scalars,
# TensorFlow tensors (`B.TFNumeric`) and NumPy arrays (`B.NPNumeric`).
_Numeric: TypeAlias = B.Number | B.TFNumeric | B.NPNumeric

10 

11 

@dispatch
def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric:  # type: ignore
    """
    Pick values out of `a` at the positions given by `index` along `axis`.

    Delegates to `tf.experimental.numpy.take_along_axis`, which follows
    NumPy's `take_along_axis` semantics.

    :param a: Source array.
    :param index: Index array selecting entries of `a` along `axis`.
    :param axis: Axis to gather along (default 0).
    """
    gathered = tf.experimental.numpy.take_along_axis(a, index, axis=axis)
    return gathered

18 

19 

@dispatch
def from_numpy(_: B.TFNumeric, b: list | B.Numeric | B.NPNumeric | B.TFNumeric):  # type: ignore
    """
    Turn `b` into a TensorFlow tensor.

    The first argument is only used so that dispatch selects this TensorFlow
    overload; its value is ignored.

    :param b: A list, scalar, NumPy array or TensorFlow tensor.
    :return: The result of `tf.convert_to_tensor(b)`.
    """
    converted = tf.convert_to_tensor(b)
    return converted

26 

27 

@dispatch
def trapz(y: _Numeric, x: _Numeric, dx=None, axis=-1):  # type: ignore
    """
    Approximate the integral of `y` with the composite trapezoidal rule.

    Delegates to `tfp.math.trapz`.

    :param y: Function values to integrate.
    :param x: Sample points corresponding to `y`.
    :param dx: Sample spacing, forwarded to `tfp.math.trapz`.
    :param axis: Axis to integrate along (default -1).
    """
    return tfp.math.trapz(y, x=x, dx=dx, axis=axis)

34 

35 

@dispatch
def norm(x: _Numeric, ord: Any | None = None, axis: int | None = None):  # type: ignore
    """
    Matrix or vector norm.

    :param x: Input array.
    :param ord: Order of the norm. ``None`` (the default) selects the
        Euclidean / Frobenius norm, matching the default of NumPy's
        `np.linalg.norm`. Any other value is passed to `tf.norm` unchanged.
    :param axis: Axis along which to take a vector norm, or ``None`` to take
        a single norm over all values of `x`.
    """
    # `tf.norm` rejects `ord=None`. Its name for the default 2-norm /
    # Frobenius norm is "euclidean", so translate the NumPy-style default.
    if ord is None:
        ord = "euclidean"
    return tf.norm(x, ord=ord, axis=axis)

42 

43 

@dispatch
def logspace(start: _Numeric, stop: _Numeric, num: int = 50, base: _Numeric = 50.0):  # type: ignore
    """
    Produce `num` values spaced evenly on a logarithmic scale.

    The result is ``base ** t``, where `t` runs linearly (via `tf.linspace`)
    from `start` to `stop`.

    NOTE(review): the default `base` here is 50.0, whereas NumPy's
    `np.logspace` defaults to 10.0. Confirm this is intended.
    """
    exponents = tf.linspace(start, stop, num)
    return tf.math.pow(base, exponents)

51 

52 

@dispatch
def degree(a: B.TFNumeric):  # type: ignore
    """
    Build the degree matrix of the adjacency matrix `a`.

    The result is a diagonal matrix whose (i, i) entry is the sum of column i
    of `a`, i.e. how many nodes node i is connected to.
    """
    column_sums = tf.reduce_sum(a, axis=0)  # type: ignore
    degree_matrix = tf.linalg.diag(column_sums)
    return degree_matrix

63 

64 

@dispatch
def eigenpairs(L: B.TFNumeric, k: int):
    """
    Return the `k` smallest eigenvalues of the symmetric positive
    semi-definite matrix `L`, together with their eigenvectors.

    :return: Tuple ``(eigenvalues, eigenvectors)``. The eigenvectors are
        the columns of the second element.
    """
    # `tf.linalg.eigh` yields eigenvalues in ascending order, so the leading
    # `k` values (and the leading `k` eigenvector columns) are the smallest.
    eigenvalues, eigenvectors = tf.linalg.eigh(L)
    smallest_values = eigenvalues[:k]
    smallest_vectors = eigenvectors[:, :k]
    return smallest_values, smallest_vectors

73 

74 

@dispatch
def set_value(a: B.TFNumeric, index: int, value: float):
    """
    Set a[index] = value.

    This operation is not done in place and a new array is returned.
    A negative `index` counts from the end, as in Python indexing.

    :param a: Tensor to update (positions run along its first axis).
    :param index: Position to overwrite.
    :param value: Value to write at `index`.
    :return: A new tensor equal to `a` except at `index`.
    """
    # `tf.shape` also works on symbolic tensors inside `tf.function`,
    # unlike `len(a)`.
    length = tf.shape(a)[0]
    # Without this, a negative index never matches any position in
    # `tf.range`, and the update would be silently dropped.
    if index < 0:
        index = index + length
    # Tensors are immutable, so select `value` at `index` and keep `a`
    # everywhere else.
    return tf.where(tf.range(length) == index, value, a)

83 

84 

@dispatch
def dtype_double(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's double-precision floating point dtype.

    `reference` only selects the TensorFlow backend through dispatch.
    """
    double_dtype = tf.float64
    return double_dtype

91 

92 

@dispatch
def float_like(reference: B.TFNumeric):
    """
    Choose a floating point dtype that fits `reference`.

    :return: `reference.dtype` if it is already floating point,
        otherwise `tf.float64`.
    """
    dtype = reference.dtype
    return dtype if dtype.is_floating else tf.float64

104 

105 

@dispatch
def dtype_integer(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's default integer dtype (`tf.int32`).

    `reference` only selects the TensorFlow backend through dispatch.
    """
    integer_dtype = tf.int32
    return integer_dtype

112 

113 

@dispatch
def int_like(reference: B.TFNumeric):
    """
    Choose an integer dtype that fits `reference`.

    :return: `reference.dtype` if it is already an integer type,
        otherwise `tf.int32`.
    """
    dtype = reference.dtype
    return dtype if dtype.is_integer else tf.int32

121 

122 

@dispatch
def get_random_state(key: B.TFRandomState):
    """
    Capture the state of a TensorFlow random generator.

    :param key:
        The random generator of type `B.TFRandomState`.
    :return: Tuple ``(state, algorithm)``: a copy of `key.state` (made with
        `tf.identity`) and `key.algorithm`. This is the pair that
        `restore_random_state` consumes.
    """
    state_copy = tf.identity(key.state)
    algorithm = key.algorithm
    return state_copy, algorithm

132 

133 

@dispatch
def restore_random_state(key: B.TFRandomState, state):
    """
    Create a random generator whose state is `state`.

    A new generator is returned. `key` is used only for dispatch; its value
    is not read.

    :param key:
        The random generator.
    :param state:
        A ``(state_tensor, algorithm)`` pair, as produced by
        `get_random_state`.
    :return: A new `tf.random.Generator` built from `state`.
    """
    state_tensor, algorithm = state[0], state[1]
    return tf.random.Generator.from_state(state=tf.identity(state_tensor), alg=algorithm)

147 

148 

@dispatch
def create_complex(real: _Numeric, imag: B.TFNumeric):
    """
    Combine a real and an imaginary part into a TensorFlow complex tensor.

    :param real:
        Real part. It is converted to a tensor and cast to the dtype of `imag`.
    :param imag:
        Imaginary part, a TensorFlow tensor.
    """
    target_dtype = B.dtype(imag)
    real_tensor = B.cast(target_dtype, from_numpy(imag, real))
    return tf.complex(real_tensor, imag)

161 

162 

@dispatch
def complex_like(reference: B.TFNumeric):
    """
    Return the complex dtype obtained by promoting `tf.complex64` together
    with the dtype of `reference`.
    """
    reference_dtype = reference.dtype
    return B.promote_dtypes(tf.complex64, reference_dtype)

169 

170 

@dispatch
def is_complex(reference: B.TFNumeric):
    """
    Tell whether `reference` has a complex dtype (`tf.complex64` or
    `tf.complex128`).
    """
    return B.dtype(reference) in (tf.complex64, tf.complex128)

177 

178 

@dispatch
def cumsum(x: B.TFNumeric, axis=None):
    """
    Return the cumulative sum of `x`.

    :param x: Input tensor.
    :param axis: Axis along which to accumulate. ``None`` (the default)
        accumulates over the flattened tensor, matching NumPy's `np.cumsum`.
    """
    # `tf.math.cumsum` needs an integer axis and fails on `axis=None`,
    # so emulate the NumPy default by flattening first.
    if axis is None:
        return tf.math.cumsum(tf.reshape(x, [-1]), axis=0)
    return tf.math.cumsum(x, axis=axis)

185 

186 

@dispatch
def qr(x: B.TFNumeric, mode="reduced"):
    """
    Compute the QR decomposition of the matrix `x`.

    :param mode: ``"complete"`` requests full-size factors
        (`full_matrices=True`). Any other value gives the reduced
        factorization.
    :return: Tuple ``(Q, R)``.
    """
    q_factor, r_factor = tf.linalg.qr(x, full_matrices=(mode == "complete"))
    return q_factor, r_factor

195 

196 

@dispatch
def slogdet(x: B.TFNumeric):
    """
    Compute the sign and the log of the absolute determinant of `x`.

    :return: Tuple ``(sign, logdet)``.
    """
    det_sign, log_abs_det = tf.linalg.slogdet(x)
    return det_sign, log_abs_det

204 

205 

@dispatch
def eigvalsh(x: B.TFNumeric):
    """
    Return the eigenvalues of the Hermitian (or real symmetric) matrix `x`.
    """
    eigenvalues = tf.linalg.eigvalsh(x)
    return eigenvalues

212 

213 

@dispatch
def reciprocal_no_nan(x: B.TFNumeric):
    """
    Element-wise ``1 / x``, with 0 returned wherever `x` is 0.
    """
    inverted = tf.math.reciprocal_no_nan(x)
    return inverted

220 

221 

@dispatch
def complex_conj(x: B.TFNumeric):
    """
    Element-wise complex conjugate of `x`.
    """
    conjugate = tf.math.conj(x)
    return conjugate

228 

229 

@dispatch
def logical_xor(x1: B.TFNumeric, x2: B.TFNumeric):
    """
    Element-wise logical exclusive OR of `x1` and `x2`.
    """
    exclusive_or = tf.math.logical_xor(x1, x2)
    return exclusive_or

236 

237 

@dispatch
def count_nonzero(x: B.TFNumeric, axis=None):
    """
    Count the non-zero entries of `x`, optionally along `axis`.
    """
    nonzero_count = tf.math.count_nonzero(x, axis=axis)
    return nonzero_count

244 

245 

@dispatch
def dtype_bool(reference: B.TFRandomState):  # type: ignore
    """
    Return TensorFlow's boolean dtype (`tf.bool`).

    `reference` only selects the TensorFlow backend through dispatch.
    """
    boolean_dtype = tf.bool
    return boolean_dtype

252 

253 

@dispatch
def bool_like(reference: B.TFNumeric):
    """
    Return the type of the reference if it is of boolean type.
    Otherwise return TensorFlow's `bool` dtype.

    :param reference: A TensorFlow tensor.
    """
    # This overload belongs to the TensorFlow backend. It reads
    # `tf.DType.is_bool` (which NumPy dtypes lack) and returns `tf.bool`,
    # so it must dispatch on TensorFlow tensors, not NumPy arrays.
    reference_dtype = reference.dtype
    if reference_dtype.is_bool:
        return reference_dtype
    return tf.bool