La costruzione del generatore richiede delle riflessioni un po' più
articolate rispetto a quelle che abbiamo adoperato per la creazione del discriminatore.
Anche il generatore sarà una rete neurale, ed in particolare vorremo far sì
che esso apprenda come "generare" un certo pattern di bit, nello
specifico il pattern 1010. Quello che ci auspichiamo è che
l'output del generatore riesca a passare l'attento controllo del
discriminatore. A tal fine, l'ultimo livello della rete del
generatore dovrà necessariamente essere composto da quattro nodi, lo stesso numero delle cifre del pattern che vogliamo generare.
Le domande che ci poniamo a questo punto sono:
-
quanti dovrebbero essere gli hidden layer della rete
neurale del generatore? -
a quanto dovrebbe corrispondere, numericamente, l'input di
ingresso della rete neurale?
Per rispondere alla prima domanda, possiamo dire che non esiste una
dimensione specifica per il numero di hidden layer della rete. Questi, in
linea di principio, dovrebbero essere in numero tale da permettere alla
rete di apprendere efficacemente senza rallentarne l'intero processo di
apprendimento.
Inoltre, l'obiettivo generale da tenere a mente rimane comunque quello di
cercare di eguagliare la velocità di apprendimento del discriminatore in
modo che uno non vada mai troppo avanti rispetto all'altro. Per queste
ragioni, molti sviluppatori invertono la struttura del Discriminatore per
la creazione di quella del generatore.
Proviamo quindi a sviluppare un generatore con un layer di input costituito da un
nodo, un hidden layer di tre nodi e un layer di output di quattro nodi, che non è
altro che la struttura speculare del discriminatore.
Il Generatore
Come tutte le reti neurali, anche quella appena descritta ha bisogno di un input. Per il momento optiamo per la scelta più semplice, utilizzando come input un valore costante. Poiché valori troppo grandi rendono
l'addestramento più difficile, usiamo inizialmente un valore di 0.5.
Per definire la classe Generator
possiamo copiare il codice dalla classe Discriminator
e apportarvi qualche modifica:
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(1, 3),
nn.Sigmoid(),
nn.Linear(3, 4),
nn.Sigmoid()
)
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
self.counter = 0
self.progress = []
pass
def forward(self, inputs):
return self.model(inputs)
La differenza principale tra il codice del discriminatore e quello del
generatore riguarda la definizione dei layer.
Nel codice del generatore non abbiamo infatti la funzione di loss che
abbiamo introdotto nel codice del discriminatore. Questo perché non ne
abbiamo bisogno all'interno del generatore. Se consideriamo il ciclo di
addestramento della GAN, vediamo che l'unica funzione di loss che usiamo è
quella che viene applicata alle uscite del discriminatore. Quando il
Generatore viene aggiornato, questo è basato sull'errore restituito da
questa funzione di loss del discriminatore.
Anche la funzione train()
del Generatore è un po' diversa rispetto a quella
del discriminatore. Mentre nel discriminatore siamo a conoscenza dell'output che
questo dovrebbe avere, nel caso del generatore non lo sappiamo. Le uniche
cose che siamo certi di possedere sono i gradienti retropropagati dalla
loss, e
calcolati dall'uscita del discriminatore nel ciclo di addestramento.
Quindi, per tale motivo, l'addestramento del generatore ha bisogno anche
del discriminatore. Ci sono diversi modi per codificare tutto quanto fin qui descritto: un modo semplice consiste nel passare il
discriminatore stesso alla funzione train()
del generatore. Questo permette di
mantenere l'addestramento chiaro, e il codice leggibile.
def train(self, D, inputs, targets):
g_output = self.forward(inputs)
d_output = D.forward(g_output)
loss = D.loss_function(d_output, targets)
self.counter += 1
if (self.counter % 10 == 0):
self.progress.append(loss.item())
pass
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
Come si vede, gli input sono passati alla
rete neurale del generatore per
mezzo di self.forwards(inputs)
. L'output del generatore g_output
viene poi
passato alla rete neurale del discriminatore per mezzo di
D.forward(g_output)
.
La loss è calcolata tra d_output
e l'obiettivo desiderato. La
retropropagazione dei gradienti di errore inizia da questa loss, che
viene retropropagata al discriminatore e poi al generatore.
Gli aggiornamenti dei pesi avvengono usando self.optimiser
e non
D.optimiser
: in questo modo solo i pesi del generatore vengono aggiornati.
Testare il Generatore
A questo punto è bene verificare che i singoli elementi stiano funzionando a
dovere. Prima di usare il generatore nell'addestramento, controlliamo che
questo produca ciò che ci aspettiamo. Quindi, creiamo un nuovo oggetto
Generator
e passiamogli un valore pari a 0.5.
G = Generator()
G.forward(torch.FloatTensor([0.5]))
Eseguendo questo codice, noteremo che il tensore di uscita del
generatore contiene quattro valori: esattamente ciò che ci aspettavamo.
Il pattern restituito non è pari a 1010, perché il generatore
non è ancora stato addestrato.
Addestrare il Generatore
Vediamo infine come addestrare il generatore, servendoci del discriminatore e mettendo a frutto tutto quello che ci siamo detti nel corso di queste lezioni:
D = Discriminator()
G = Generator()
for i in range(10000):
D.train(generate_real(), torch.FloatTensor([1.0]))
D.train(G.forward(torch.FloatTensor([0.5])).detach(),torch.FloatTensor([0.0]))
G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))
Per prima cosa creiamo gli oggetti Discriminator
e Generator
, ovvero D
e G
.
Ci serviremo di questi per avviare il nostro ciclo di addestramento, che durerà
10.000
iterazioni.
Il discriminatore viene addestrato sulla base
di dati reali. Successivamente addestriamo il discriminatore con un pattern
restituito dal generatore. La funzione detach()
è quindi
applicata all'output del generatore. La spiegazione di questo concetto è un
po' articolata: normalmente, l'esecuzione della funzione backwards()
sulla loss del
discriminatore fa sì che i gradienti della loss siano calcolati lungo tutto
il computation graph - dalla loss del discriminatore, attraverso il
discriminatore stesso e poi indietro attraverso il generatore; ma poiché
stiamo addestrando solo il discriminatore, non abbiamo bisogno di calcolare
i gradienti per il generatore stesso. Quel detach()
, quindi, applicato
all'output del generatore, taglia il computation graph in quel punto.
L'immagine seguente illustra questo concetto:
Conclusioni
In queste lezioni abbiamo creato una GAN in grado di generare un semplice pattern di bit. Chiaramente tutti questi concetti possono essere trasposti nel mondo delle immagini e per lavorare in scenari ben più complessi, apportando gli opportuni aggiustamenti nella definizione delle strutture di discriminatore e generatore, ed evidentemente ottimizzando in modo opportuni gli iperparametri che ne derivano.