Neural Cryptography: Stop me before I SIGKILL again

A plea for help with the only thing that transformers can't seem to do.

The Problem

The year is 2020 and the proliferation of Big Old Language Models is still a faraway dream. I am staring at code from a client (code has been changed to protect the innocent):

# gotta make a token and send it to the client!
very_random_number = get_insecure_random_number()
# e.g. (.0324191942[...]) => "ABXDC"
two_factor_token = convert_representation(very_random_number)
send_email("Your two factor authentication token is: " + two_factor_token, user_email)
save_token_to_user(user_id, two_factor_token)

I have seen this code and many variations of it. There are many services in your life which strongly resemble it. Is it secure?

There are good reasons to say "no, obviously not" to the above question.

But there is one big problem: what is the POC? If I have no POC, must I not GTFO? There are theoretical weaknesses, but actually predicting tokens based on previous tokens has not been done. It might be possible with a bespoke SAT solver solution, but that's certainly beyond my ability, and it still depends on an attacker having access to the source code. Should this be fixed immediately? Or merely noted as a low-severity finding (an industry term that means 'never fixed')?

This underlying conundrum was the subject of my previous article, which attacks XORSHIFT128 with the goal of eventually attacking XORSHIFT128+. That would be this algorithm, which is used in Chrome and is thus probably the most common RNG in use:

MAXSIZE = 2**64 - 1  # xorshift128+ works on 64-bit words, so everything is masked to 64 bits

def xorshift128plus(x, y):
    s0, s1 = y, x
    s1 ^= (s1 << 23) & MAXSIZE
    s1 ^= (s1 >> 17)
    s1 ^= s0
    s1 ^= (s0 >> 26)
    x = y
    y = s1
    generated = (x + y) & MAXSIZE
    return generated, x, y

The techniques from that article are primitive in the cold light of 2024; LSTMs were already being replaced by Transformers, and at this point they have been eclipsed entirely. But now I'm wiser, I have a faster computer, and more importantly there have been approximately ten million new advancements in machine learning. How much harder could it be?

Much harder, as it turns out.

The graveyard of failed techniques

Several years of tinkering in my free time have resulted in no success in this field of problems. The original problem is this:

Given the previous N outputs of xorshift128+, predict the next output of xorshift128+
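To make that concrete, here is a minimal sketch of how training examples can be generated with the xorshift128plus function above; the window size n_prev and the seed values are arbitrary choices of mine, not anything from the repository.

def make_examples(count, n_prev=4, x=0x123456789ABCDEF0, y=0x0FEDCBA987654321):
    # Run the generator long enough to fill every window plus its target.
    outputs = []
    while len(outputs) < count + n_prev:
        out, x, y = xorshift128plus(x, y)
        outputs.append(out)
    # Each example pairs the previous n_prev outputs with the next output.
    return [(outputs[i:i + n_prev], outputs[i + n_prev]) for i in range(count)]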

It is a straightforward problem; given that it is solvable with a SAT solver, it should be possible for a neural network, our beloved universal function approximator (it probably is possible to approximate this function, but it would also be convenient if it were possible to learn this representation). But no amount of compute or hyperparameter tweaking succeeded in any way: not running at home on my beefy M1, not on rented GPU time from AWS. There is a wealth of literature about making your network stop overfitting, but seemingly none about what to do when your network doesn't fit at all.

With this in mind, I tried to narrow down my dream somewhat.

I have learned much about the innards of neural network design in the interim, but none of it has actually solved this problem. I have tried every kind of transformer stack that fits on my machine. Given that the input is ideally 128 bits and certainly no larger than 256 bits, you can get a pretty large network into an M1 with 64 GB of RAM. But to emphasize: no combination of architectures, regularization techniques, or other tricks seems to get the job done. This went up to networks that hit the O(N²) memory bottleneck for transformers (it's not particularly difficult to generate millions of test cases, so overfitting is the least of my problems). All of them dutifully run for hours or days, stuck at a mean error no better than chance, before I terminate them.

So, in the face of all reason, three and a half years of tinkering has produced more or less nothing. It is difficult to say exactly why, but there are several compelling candidate reasons.

Having exhausted the range of general neural networks, I decided to commit the unforgivable sin of designing a network architecture specifically for the problem at hand. If I hand-designed a network to perform the forward pass from state to output, could I use it to calculate the state from the output? It would not be a complete solution, but it would be progress.

Gradient Descent Into Madness

Conventional machine learning follows a tried-and-true pattern. We define a function which is itself defined in terms of its parameters: each node in a densely connected network has a large pile of weights that determine its output values. Combined together, these nodes let the network approximate a surprising number of functions. It can also approximate a lot of useless functions, so training is needed: we calculate the function's error given some weights and then change the weights slightly to reduce that error across the training examples.

In this case, however, I would like to calculate the inputs given the outputs. Thus, my idea was to perform gradient descent on the inputs rather than on the weights: hold the hand-built network fixed, treat the unknown input state as the thing being optimized, and nudge it until the network's output matches the output that was actually observed.
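As a minimal sketch of the idea (model here stands for the hand-built differentiable forward pass described below, and observed_output for the bits we actually saw; both names are mine), the loop looks something like this:

import torch

def descend_on_inputs(model, observed_output, bits=128, steps=1000, lr=0.05):
    # The candidate state bits are the parameters here, not the network weights.
    state = torch.full((1, bits), 0.5, requires_grad=True)
    optimizer = torch.optim.Adam([state], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        predicted = model(state)
        loss = torch.nn.functional.mse_loss(predicted, observed_output)
        loss.backward()
        optimizer.step()
        # Keep every candidate bit inside [0, 1].
        with torch.no_grad():
            state.clamp_(0.0, 1.0)
    return state.detach()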

I would like to do this in a way that is as general as possible, while acknowledging that the current state of the art is too general. XORSHIFT128+ is built from three components, which constitute the majority of RNG operations: XOR, bit shifting, and addition with carry (there are just not that many bitwise operations that don't leak entropy, so the list has to be pretty short).

However, we can't just use the built-in PyTorch functions like xor and roll for these, as they are not smoothly differentiable. Performing gradient descent is only possible if every operation has a gradient! If we want to perform gradient descent on random input, every operation also needs to have a meaningful gradient across the entire domain of [0, 1). So it is written, so it shall be done.

XOR

XOR's inputs for a given bit can have one of two values, which (when extended to real-valued input) means anything between 0 and 1. To make the math easier we will actually have these be -1 and 1 inside of nodes (it makes it possible to do the matrix multiplication trick at the bottom). Let's start by trying to "push" these values towards the edges. I accomplish this with the 'certainty' parameter, which indicates how much the values should be flattened into one of the two extremes. There are many real-valued functions that would do this, but I use:

xx = (x - 0.5) * 2                        # map [0, 1] onto [-1, 1]
inputs = torch.tanh(xx * self.certainty)  # push the values towards -1 and 1

which takes an input vector x and turns it into something much closer to one of those two values. Higher certainty means more 'pushing' towards the extremes, but also makes the function's gradients much more jagged. As we'll come to see, this matters a lot: too much certainty, and the gradients are impossible to calculate; not enough certainty, and error propagates through the operations and befouls the answer.

Because each bit can have one of two values, there are four possible input combinations that have meaning: [[-1,-1],[-1,1],[1,-1],[1,1]]. Calculating the similarity between a given input and these four combinations is quite simple: we multiply the above 4x2 matrix by a 2x1 matrix holding the input and receive four numbers, each of which is larger the more closely the input matches the corresponding combination. Our old friend softmax comes to the rescue to convert this mix of positive and negative numbers into percentages of certainty. I also use certainty here to make the calculations favor one position over the others:

target = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).float().T
prod = torch.matmul(inputs, target)                      # similarity to each combination
best_name = torch.softmax(prod * self.certainty, dim=2)  # near-one-hot weighting over the four

The resulting 4x1 matrix then needs to be converted into the correct output, which I have defined with a 1x4 lookup table. For example:

[.01,.01,.97,.01] 

can be multiplied by the XOR lookup table:

[0,1,1,0]  # transposed

producing a value very close to 1. There are some reasons to believe that this isn't the best solution for the problem at hand, but it works and is simple for each bit. We then simply perform this operation column-wise on a set of input vectors, letting us perform half of an operation like x ^= x << 23; that is ubiquitous in all of cryptography (with a little static single assignment to avoid the in-place shifts).
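Putting the pieces above together, a minimal self-contained sketch of the whole soft XOR might look like the following; the names and shapes are my guesses rather than the repository's exact code.

import torch

def soft_xor(a, b, certainty=6.0):
    # a, b: tensors of per-bit values in [0, 1]; returns their soft XOR, bit by bit.
    # Map [0, 1] to [-1, 1] and push the values towards the extremes.
    inputs = torch.tanh((torch.stack([a, b], dim=-1) - 0.5) * 2 * certainty)
    # The four possible input combinations, one per column.
    target = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]]).T
    # Similarity of each bit pair to each combination, softmaxed into a
    # near-one-hot weighting over the four combinations.
    weights = torch.softmax(torch.matmul(inputs, target) * certainty, dim=-1)
    # The XOR lookup table for those four combinations.
    xor_table = torch.tensor([0., 1., 1., 0.])
    return torch.matmul(weights, xor_table)

# e.g. soft_xor(torch.tensor([0., 1., 1.]), torch.tensor([1., 1., 0.]))
# gives values very close to [1., 0., 1.]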

Bit Shifting

The other half of these operations involves bit shifting, which is itself a pretty simple operation. Given a list of bits [0,1,1,0...], a bit width, and an integer N, we want to shift them to the left or the right. So an 8-bit integer:

	00001111
	
shifted 2 to the left becomes:
	00111100
	
or shifted 2 to the right becomes:
	00000011
	

Any positions not filled previously become 0; we don't wrap them around. Torch provides torch.roll, which mimics this functionality but with wrapping: we would get 11000011 on the shift right, which is not what we want. The fix for this is pretty trivial, but it is not smoothly differentiable, so it is dead to me. I need a smooth shift for my smooth brain.

Given a vector of length N, we can get an identity matrix by putting 1s along the top-left to bottom-right diagonal. Multiplying the vector by this matrix changes nothing.

But we can subtly alter this matrix to perform more exciting and dynamic operations: moving the 1 in a given row to the column you'd like that bit to end up in means that multiplying an input vector by the matrix permutes its entries, as in the sketch below.
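As a tiny worked example, assuming 4 bits stored most-significant-bit first and multiplied as a row vector:

import torch

v = torch.tensor([0., 0., 1., 1.])       # the bits 0011

identity = torch.eye(4)
print(v @ identity)                       # tensor([0., 0., 1., 1.]) - unchanged

# Move each diagonal 1 off the diagonal, then zero the row that torch.roll
# wrapped around so that nothing wraps.
shift_right_1 = torch.roll(torch.eye(4), -1, 0)
shift_right_1[-1, :] = 0.0
print(v @ shift_right_1)                  # tensor([0., 0., 0., 1.]) - 0011 >> 1 == 0001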

Thus, shifting each diagonal 1 to the left or the right and zeroing out the entries that would wrap around lets us create a matrix that shifts by the appropriate amount. How then to turn this into something differentiable? I define a tensor with all the possible useful shifts: for a bit width of N, we can shift anywhere between 1-N (all the way to the right) and N-1 (all the way to the left):

def bit_shift_matricies(bits):
    # One permutation matrix per possible shift amount, from 1-bits to bits-1;
    # brange(bits) is assumed to yield exactly that range.
    eyes = torch.stack(
        [torch.roll(torch.eye(bits), i, 0) for i in brange(bits)]
    )
    for i in range(2 * bits - 1):
        # bit_to_minus maps the stack index back to a signed shift amount.
        b = bit_to_minus(i, bits)
        if b < 0:
            eyes[i][b:] = 0.0   # zero the rows that torch.roll wrapped around
        else:
            eyes[i][:b] = 0.0
    return eyes

Given this, a parameter between 1-N and N-1 can be used to return a weighted average of these matrices. That is, we extend the shift function to fractional amounts: a shift partway between N-1 and N is defined by taking the weighted average of those two shifts' outputs, at the bitwise level.

def differentiable_shift(x, n, eyes, certainty, bits):
    # A normalized, gaussian-like distribution over all legal shift amounts, peaked at n.
    shift = weighed_smooth_vector(
        torch.arange(1 - bits, bits, 1), n, 1.0 / certainty
    )
    # Blend the per-shift permutation matrices by those weights, then apply the result.
    mults = torch.sum(shift[:, None, None] * eyes, 0)
    return torch.matmul(x, mults)
            

The shift vector needs to produce a probability distribution that assigns the highest value to n, which represents the current shift amount. A gaussian distribution normalized over the legal values is used here, and each shift matrix is then weighted by its value so that their sum is the right matrix. With a reasonably high value for certainty (~6) this becomes 1 for the given shift and 0 for all others while also being differentiable (but not very differentiable; the gradients seem to be very small).
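A minimal sketch of what that weighting might look like (the repository's weighed_smooth_vector helper presumably does something along these lines; the exact normalization is my assumption):

import torch

def gaussian_shift_weights(values, n, width):
    # values: all legal shift amounts; n: the desired shift; width: spread (roughly 1/certainty).
    logits = -((values.float() - n) ** 2) / (2 * width ** 2)
    weights = torch.exp(logits)
    return weights / weights.sum()     # normalize over the legal shifts only

# e.g. the weights over shifts -7..7 for an 8-bit value, peaked at a shift of 2:
w = gaussian_shift_weights(torch.arange(-7, 8), n=2.0, width=1.0 / 6)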

Addition with Carry

Adding with carry is the last required operation: given a pair of numbers expressed bitwise, we would like to produce the bitwise sum in a smoothly differentiable way. We use a technique similar to the XOR gates: each combination of bitwise inputs is matched via matrix multiplication to one of the eight possible input combinations, which then have a fixed table of outputs (at the hardware level, add-with-carry is carried out with logic gates, so it shouldn't be too shocking to see them here).

def forward(self, x):
    x = 2 * (x - 0.5)  # moves 0,1 to the -1,+1 domain
    # multiplies a 1x3 matrix by a 3x8 matrix to produce a 1x8 matrix
    x = torch.matmul(x, self.truthTableInputs)
    # subsequently turns this 1x8 matrix into a probability distribution
    x = torch.softmax(x * self.certainty, dim=1)
    # then multiplies it by the correct outputs for each combination
    # to produce an output and the carry bit for the next column
    x = torch.matmul(x, self.truthTableOutputs)
    return x
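The two truth-table buffers referenced above are not shown in the snippet; a plausible construction, with the ordering of the combinations being my guess, would be:

import torch

# The eight (a, b, carry_in) combinations in the -1/+1 domain, one per column.
combos = torch.tensor(
    [[a, b, c] for a in (-1., 1.) for b in (-1., 1.) for c in (-1., 1.)]
)
truthTableInputs = combos.T            # shape (3, 8)

# For each combination, the full adder's outputs: (sum bit, carry-out bit).
def adder_outputs(a, b, c):
    total = int(a > 0) + int(b > 0) + int(c > 0)
    return [float(total % 2), float(total // 2)]

truthTableOutputs = torch.tensor([adder_outputs(*row) for row in combos])   # shape (8, 2)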
				

Moving from left to right, this component of the network produces a row of output bits and a row of carry bits, which are used to perform bitwise arithmetic in the same way that a computer does. The associated network simply performs these gates in the right order.
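For completeness, here is a minimal sketch of how that chaining might look, assuming a full_adder module whose forward() is the one above (taking a (1, 3) tensor of a, b, and carry-in bits and returning a (1, 2) tensor of sum and carry-out) and bits stored least-significant first:

import torch

def soft_add(a_bits, b_bits, full_adder):
    # a_bits, b_bits: per-bit values in [0, 1], least-significant bit first.
    carry = torch.tensor(0.0)
    out_bits = []
    for a, b in zip(a_bits, b_bits):
        inp = torch.stack([a, b, carry]).unsqueeze(0)   # shape (1, 3)
        result = full_adder(inp)
        out_bits.append(result[0, 0])                   # the sum bit for this column
        carry = result[0, 1]                            # the carry into the next column
    # The final carry is dropped, matching addition modulo 2^N.
    return torch.stack(out_bits)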

To come so far, only to fail

With these three pieces in hand, we can calculate the correct output quite easily by chaining them together. The repository's verify_model_works.py randomly selects several thousand inputs and outputs and verifies that the output of the torch version matches the RNG itself, both for the output and the state. Varying the certainty parameters throughout the model can make it more differentiable, but also shows that error is eventually introduced into the calculations; the values I've chosen were the smallest whole numbers I could get to work.

Unfortunately, this doesn't work for the proposed task. The associated perform_descent.py code is currently configured to perform this gradient descent. There are several interesting questions which I invite you to play with: if almost all of the bits are already correct, will the model converge? Will it converge if all the inputs are set randomly, or all very close to 0.5? I was surprised to discover that even when all but one bit of the input is set to its correct value, the model does not always converge to the correct answer.

Thus, what I need is help. If you are purely curious about the topic, that's good enough for me, but there are practical reasons as well. Many applications I have seen rely on these types of insecure RNGs, shielded only by the obfuscation of their source code. I also don't see any theoretical reason why a CSPRNG would withstand an attack like this any better (though obviously I would really like to find one; that's easy to say after not discovering anything that works, but the individual operations are identical).

I have tried a lot of things that have not worked, but there are several reasonable ideas left to try.

Although it is painful to publish a negative result, the alternative is losing my mind. I have many more ideas to try, but I am not sure that they will work, and I don't know anyone else working on the topic. There are several papers in the area available online, but almost all of them are... sketchy, with little genuinely interesting output and some purporting to do things that are seemingly impossible (my favorite by far is this, which somehow purports to learn digits of π and e via supervised learning, claiming statistical significance through an incredible abuse of statistics). There is a dearth of good research despite the importance of any successful work in the area and the relative ease of experimentation: all the code here runs easily on my MacBook without really taxing the machine at all. It is a curious gap.

The provided repository has been winnowed down from my previous work on the subject and provides what I hope is a straightforward example of how to attack the problem. If you have any suggestions that I have not thought of, I will happily implement them and report back if they seem interesting. I am particularly interested in the expertise of someone who has worked in ML research or cryptography. I am also seeking employment in this area: if you would like me to grind my face against this or another related problem, for money, I would be thrilled to do so.

Please send me a message if you have anything you think would be helpful.